﻿#include "UnityCG.cginc"
#include "TraceUtil.cginc"
#include "VolumeSampler.cginc"

#pragma multi_compile_local _ USE_DDA
#pragma multi_compile_local _ USE_TRANSFERFUNC
#pragma multi_compile_local _ MIP_FIXED
#pragma multi_compile_local _ USE_LIGHT_SOURCE_DIRECTIONAL USE_LIGHT_SOURCE_POINT

// --------------------------------------------------------------
// Shader parameters (uniforms); presumably set from the C# side — not visible here.

// Volume: world-space axis-aligned bounding box. Every tracer in this file clips
// world-space rays against it with intersect_box().
float3 vol_bb_min;
float3 vol_bb_max;
// Scattering albedo: multiplies the path throughput at real collisions; (1 - vol_albedo)
// weights the emission term accumulated during tracking.
float3 vol_albedo;
// Henyey-Greenstein anisotropy parameter passed to the phase-function helpers.
float vol_phase_g;
// Majorant mip level the DDA tracers start from when MIP_FIXED is defined.
int mip_range_fixed;

// density brick grid transforms
// vol_density_inv_transform maps world space to the grid's index space (applied to ray
// origins with w = 1 and to directions with w = 0). vol_density_transform is only
// referenced by the commented-out lookup_emission code.
float4x4 vol_density_transform;
float4x4 vol_density_inv_transform;

// environment
//float3x3 env_transform;
//float3x3 env_inv_transform;
// Scale applied to radiance read from _EnvMap.
float env_strength;
// Reciprocal resolution of mip 0 of _EnvImpMap; converts texel coordinates to uv.
float2 env_imp_inv_dim;
// Coarsest mip level of the importance hierarchy in _EnvImpMap: its texel (0, 0) holds the
// average weight used to normalize pdfs, and sample warping starts one level below it.
int env_imp_base_mip;

// light source
// NOTE(review): light_world_pos / light_world_dir are not read anywhere in this file.
float3 light_world_pos;
float3 light_world_dir;

// Environment radiance map (equirectangular, looked up by direction) and its importance
// map (weights read from .r, with a mip hierarchy used for hierarchical sample warping).
UNITY_DECLARE_TEX2D(_EnvMap);
UNITY_DECLARE_TEX2D(_EnvImpMap);

// temperature brick grid stored as textures (currently disabled)
//float4x4 vol_emission_transform;
//float4x4 vol_emission_inv_transform;

//Texture3D<uint4> vol_emission_indirection; // maybe wrong alignment
//Texture3D<half2> vol_emission_range;
//Texture3D vol_emission_atlas;

// Camera-to-world matrix and inverse projection, used by camera_ray().
float4x4 cam_to_world;
float4x4 cam_inv_proj;

// path tracing
// Maximum number of real collisions (scattering events) per path in trace_path().
uint bounces;
// Environment visibility toggle: when 0, rays that never hit the volume return
// transparent black instead of environment radiance.
int show_environment;

// A ray given by origin and direction. The direction is not required to be unit length:
// index-space rays derived from world rays keep the scaling of vol_density_inv_transform
// (marked "non-normalized!" where they are built), so a parameter t is shared by both.
struct ray
{
    float3 origin; // ray start point
    float3 dir;    // ray direction (possibly non-normalized)
};

// Builds a ray from an origin and a direction; the direction is stored as given (not normalized).
ray create_ray(const float3 origin, const float3 dir)
{
    ray result;
    result.dir = dir;
    result.origin = origin;
    return result;
}

// --------------------------------------------------------------
// camera helper

// replaced view_dir function
// https://www.gamedeveloper.com/programming/gpu-ray-tracing-in-unity-part-1
// Primary ray through pixel xy (plus a sub-pixel offset) of a resolution-sized image.
// The origin is the camera position in world space; the direction is unit length.
ray camera_ray(const uint2 xy, const uint2 resolution, const float2 pixel_sample)
{
    // pixel + sub-pixel offset -> normalized device coordinates in [-1, 1]
    const float2 pixel = xy + pixel_sample;
    const float2 ndc = pixel / float2(resolution) * 2.0f - 1.0f;

    // camera position: transform the view-space origin to world space
    const float3 cam_pos = mul(cam_to_world, float4(0, 0, 0, 1)).xyz;

    // unproject to a view-space direction, then rotate it into world space (w = 0)
    const float3 view_dir = mul(cam_inv_proj, float4(ndc, 0, 1)).xyz;
    const float3 world_dir = normalize(mul(cam_to_world, float4(view_dir, 0)).xyz);

    return create_ray(cam_pos, world_dir);
}

/*
// brick grid voxel temperature lookup (nearest neighbor)
float lookup_temperature_brick(const float3 ipos) {
    const int3 iipos = int3(floor(ipos));
    const int3 brick = iipos >> 3;
    const float2 range = vol_emission_range[brick].xy;
    const uint3 ptr = vol_emission_indirection[brick].xyz;
    const float value_unorm = vol_emission_atlas[int3(ptr << 3) + (iipos & 7)].x;
    return range.x + value_unorm * (range.y - range.x);
}

// emission lookup (stochastic trilinear filter)
float3 lookup_emission(const float3 ipos, inout uint seed) {
    const float3 ipos_emission = (mul(vol_emission_inv_transform, mul(vol_density_transform, float4(ipos, 1.0f)))).xyz;
    const float t = lookup_temperature_brick(ipos_emission + rng3(seed) - .5f) * vol_emission_norm;
    return vol_emission_scale * sqr(float3(t, sqr(t), sqr(sqr(t))));
}
*/

// --------------------------------------------------------------
// environment helper (input vectors assumed in world space!)

// Environment radiance along a world-space direction: equirectangular lookup of _EnvMap at
// mip 0, scaled by env_strength. dir is expected to be (approximately) unit length.
float3 lookup_environment(const float3 dir) {
    const float3 idir = dir;
    //const float3 idir = mul(env_inv_transform, dir);
    const float u = atan2(idir.z, idir.x) / (2 * M_PI) + 0.5f;
    // clamp: rounding can push |y| of a normalized vector slightly past 1, where acos is NaN
    const float v = 1.f - acos(clamp(idir.y, -1.f, 1.f)) / M_PI;
    return env_strength * UNITY_SAMPLE_TEX2D_LOD(_EnvMap, float2(u, v), 0).rgb;
}

// Importance-samples a direction from the environment by hierarchically warping `rng` down
// the mip chain of the importance map _EnvImpMap (from env_imp_base_mip - 1 to mip 0).
//   rng : two uniform random numbers, expected in [0, 1)
//   w_i : (out) sampled world-space direction
// Returns float4(radiance, pdf), with the pdf scaled by INV_4_PI. The pdf is 0 when
// USE_LIGHT_SOURCE_DIRECTIONAL is defined, matching environment_light() in that mode.
float4 sample_environment(const float2 rng, out float3 w_i) {
    int2 pos = int2(0, 0);   // texel position on the current mip
    float2 p = rng;          // sub-texel position, re-normalized after each split
    // warp sample over mip hierarchy
    for (int mip = env_imp_base_mip - 1; mip >= 0; mip--) {
        pos *= 2; // scale to mip
        // Texel-exact fetches of the four child texels from the *importance* map.
        // Load takes int3(x, y, mip); a filtered sample would blend neighbouring weights.
        float w[4];
        w[0] = _EnvImpMap.Load(int3(pos + int2(0, 0), mip)).r;
        w[1] = _EnvImpMap.Load(int3(pos + int2(1, 0), mip)).r;
        w[2] = _EnvImpMap.Load(int3(pos + int2(0, 1), mip)).r;
        w[3] = _EnvImpMap.Load(int3(pos + int2(1, 1), mip)).r;
        float q[2]; // column sums: left / right
        q[0] = w[0] + w[2];
        q[1] = w[1] + w[3];
        // horizontal
        int off_x;
        const float d = q[0] / max(1e-8f, q[0] + q[1]);
        if (p.x < d) { // left
            off_x = 0;
            p.x = p.x / d;
        } else { // right
            off_x = 1;
            p.x = (p.x - d) / (1.f - d);
        }
        pos.x += off_x;
        // vertical, within the chosen column (guarded against an all-zero column)
        const float e = w[off_x] / max(1e-8f, q[off_x]);
        if (p.y < e) { // bottom
            p.y = p.y / e;
        } else { // top
            pos.y += 1;
            p.y = (p.y - e) / (1.f - e);
        }
    }
    // compute sample uv coordinate and (world-space) direction
    const float2 uv = (float2(pos) + p) * env_imp_inv_dim;
    const float theta = saturate(1.f - uv.y) * M_PI;
    const float phi   = (saturate(uv.x) * 2.0f - 1.f) * M_PI;
    const float sin_t = sin(theta);
    w_i = float3(sin_t * cos(phi), cos(theta), sin_t * sin(phi));
    //w_i = mul(env_transform, float3(sin_t * cos(phi), cos(theta), sin_t * sin(phi)));

    // Radiance: filtered lookup at uv, the same direction->uv mapping lookup_environment()
    // uses for w_i. (Indexing the texture with a float uv via operator[] truncates to texel 0.)
    const float3 Le = env_strength * UNITY_SAMPLE_TEX2D_LOD(_EnvMap, uv, 0).rgb;

#ifdef USE_LIGHT_SOURCE_DIRECTIONAL
    // environment light sampling is disabled in this mode (environment_light() returns 0)
    const float pdf = 0.f;
#else
    // piecewise-constant pdf: mip-0 importance of the chosen texel over the average weight
    const float avg_w = _EnvImpMap.Load(int3(0, 0, env_imp_base_mip)).r;
    const float pdf = _EnvImpMap.Load(int3(pos, 0)).r / avg_w;
#endif

    return float4(Le, pdf * INV_4_PI);
}

// Solid-angle pdf (scaled by INV_4_PI) with which sample_environment() produces the
// world-space direction `dir`. The importance texel of `dir` is evaluated directly, so the
// MIS weights use exactly the sampler's density, independent of env_strength or of any
// resolution mismatch between _EnvMap and _EnvImpMap.
float pdf_environment(const float3 dir) {
    // direction -> equirectangular uv, same mapping as lookup_environment()
    // (the inverse of the uv -> direction mapping in sample_environment())
    const float u = atan2(dir.z, dir.x) / (2 * M_PI) + 0.5f;
    const float v = 1.f - acos(clamp(dir.y, -1.f, 1.f)) / M_PI;
    // mip-0 importance texel containing uv (env_imp_inv_dim is the reciprocal resolution)
    const float2 dim = 1.f / env_imp_inv_dim;
    const int2 texel = int2(clamp(float2(u, v) * dim, float2(0, 0), dim - 1.f));
    const float avg_w = _EnvImpMap.Load(int3(0, 0, env_imp_base_mip)).r;
    const float pdf = _EnvImpMap.Load(int3(texel, 0)).r / avg_w;
    return pdf * INV_4_PI;
}

// Light-sampling pdf of the environment for direction `dir`, used as the MIS counterpart for
// rays that escape the volume. `pos` is currently unused. Always 0 when
// USE_LIGHT_SOURCE_DIRECTIONAL is defined.
float environment_light(const float3 pos, const float3 dir)
{
    float pdf = 0;
#ifndef USE_LIGHT_SOURCE_DIRECTIONAL
    pdf = pdf_environment(dir);
#endif
    return pdf;
}

// --------------------------------------------------------------
// transfer function helper

#ifdef USE_TRANSFERFUNC

// transfer-function lookup table (tf_size RGBA entries) and its density window
StructuredBuffer<float4> tf_lut;

uint tf_size;
float tf_window_left;
float tf_window_width;

// Maps a density into the window [tf_window_left, tf_window_left + tf_window_width],
// normalized and clamped to [0, 1) so it always indexes a valid LUT entry.
float tf_window(const float d) {
    const float rel = (d - tf_window_left) / tf_window_width;
    return clamp(rel, 0.0, 1.0 - 1e-6);
}

// Nearest-entry LUT lookup (RGBA) for density d.
float4 tf_lookup(const float d) {
    const int slot = int(floor(tf_window(d) * tf_size));
    return tf_lut[slot];
}

#endif

// --------------------------------------------------------------
// box intersect helper

// Slab test of `ray` against the AABB [bb_min, bb_max].
// near_far receives the entry (clamped to >= 0) and exit ray parameters; returns whether
// the clamped interval is non-empty.
bool intersect_box(const ray ray, const float3 bb_min, const float3 bb_max, out float2 near_far) {
    const float3 rcp_dir = 1.f / ray.dir;
    const float3 t_lo = (bb_min - ray.origin) * rcp_dir;
    const float3 t_hi = (bb_max - ray.origin) * rcp_dir;
    const float3 t_enter = min(t_lo, t_hi);
    const float3 t_exit = max(t_lo, t_hi);
    const float entry = max(0.f, max(t_enter.x, max(t_enter.y, t_enter.z)));
    const float exit = min(t_exit.x, min(t_exit.y, t_exit.z));
    near_far = float2(entry, exit);
    return entry <= exit;
}

// --------------------------------------------------------------
// null-collision methods

// Stochastic transmittance estimate along world_ray through the volume AABB using ratio
// tracking against the global majorant (vol_majorant / vol_inv_majorant, presumably declared
// in VolumeSampler.cginc — not visible here).
//   world_ray : world-space ray; its parameter t is shared with the derived index-space ray
//   seed      : RNG state, advanced by every rng() call
//   t_max     : optional far limit on the ray parameter (default ~FLT_MAX, i.e. no limit)
// Returns 1 if the ray misses the AABB, otherwise the ratio-tracking estimate, which russian
// roulette may terminate to 0.
float transmittance(const ray world_ray, inout uint seed, float t_max = 3.402823e+38f) {
    // clip volume
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return 1.f;
    near_far.y = min(t_max, near_far.y);
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    // ratio tracking: first tentative collision at an exponentially distributed distance
    float t = near_far.x - log(1 - rng(seed)) * vol_inv_majorant, Tr = 1.f;
    while (t < near_far.y) {
#ifdef USE_TRANSFERFUNC
        // TF alpha of the normalized density, rescaled to an extinction value
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float d = vol_majorant * rgba.a;
#else
        const float d = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
#endif
        // track ratio of real to null particles
        Tr *= 1 - d * vol_inv_majorant;
        // russian roulette: once Tr < 0.1, terminate with probability 1 - Tr, else rescale
        if (Tr < .1f) {
            const float prob = 1 - Tr;
            if (rng(seed) < prob) return 0.f;
            Tr /= 1 - prob;
        }
        // advance
        t -= log(1 - rng(seed)) * vol_inv_majorant;
    }
    return Tr;
}

// Delta tracking against the global majorant: finds the next real collision along world_ray
// inside the volume AABB.
//   t          : (out) ray parameter of the collision (left unset if the AABB is missed)
//   throughput : scaled by vol_albedo (and the TF color with USE_TRANSFERFUNC) on a real collision
//   Le         : accumulates throughput * (1 - vol_albedo) * P_real at every tentative collision
//                (constant emission term; the temperature-based lookup is commented out)
//   seed       : RNG state
// Returns true on a real collision, false if the ray leaves the volume (or misses it).
bool sample_volume(const ray world_ray, out float t, inout float3 throughput, inout float3 Le, inout uint seed) {
    // clip volume
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return false;
    
    // if(near_far.x < 0.5f) return false;
    // if(t < 0.5f) return false;
    
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    // delta tracking: first tentative collision at an exponentially distributed distance
    t = near_far.x - log(1 - rng(seed)) * vol_inv_majorant;
    
    // if(t < 0.5f) return false;
    
    while (t < near_far.y) {
#ifdef USE_TRANSFERFUNC
        // TF alpha of the normalized density, rescaled to an extinction value
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float d = vol_majorant * rgba.a;
#else
        const float d = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
#endif
        // probability that this tentative collision is real
        const float P_real = d * vol_inv_majorant;
        Le += throughput * (1 - vol_albedo) * P_real;
        //Le += throughput * (1 - vol_albedo) * lookup_emission(ipos + t * idir, seed) * P_real;
        // classify as real or null collison
        if (rng(seed) < P_real) {
#ifdef USE_TRANSFERFUNC
            throughput *= rgba.rgb * vol_albedo;
#else
            throughput *= vol_albedo;
#endif
            return true;
        }
        // advance
        t -= log(1 - rng(seed)) * vol_inv_majorant;
    }
    return false;
}

// --------------------------------------------------------------
// DDA-based null-collision methods

// Adaptive majorant-mip traversal parameters for the DDA tracers (cell size is 8 << mip):
//   MIP_START      : mip level the traversal starts at (unless MIP_FIXED is defined)
//   MIP_SPEED_UP   : added to the fractional mip after each DDA step (coarser cells), capped at 3
//   MIP_SPEED_DOWN : subtracted after a tentative collision (finer cells), floored at 0
#define MIP_START 3
#define MIP_SPEED_UP 0.25
#define MIP_SPEED_DOWN 2

// perform DDA step on given mip level
// Distance (in ray-parameter units, relative to pos) from pos to the boundary of its DDA cell
// on the given mip level; cells are (8 << mip) voxels wide, padded by half a voxel.
float stepDDA(const float3 pos, const float3 inv_dir, const int mip) {
    const float cell = 8 << mip;
    // 1 on axes with a non-negative direction (exit through the far face), else 0
    const float3 forward = step(float3(0, 0, 0), inv_dir);
    const float3 exit_offset = lerp(
        float3(-0.5f, -0.5f, -0.5f),
        float3(cell + 0.5f, cell + 0.5f, cell + 0.5f),
        forward);
    const float3 cell_origin = floor(pos * (1.0f / cell)) * cell;
    const float3 t_exit = (cell_origin + exit_offset - pos) * inv_dir;
    return min(min(t_exit.x, t_exit.y), t_exit.z);
}

// Absolute ray parameter (measured from ray origin ro, with reciprocal direction ri) at which
// the ray leaves the DDA cell containing pos; cells are (8 << mip) voxels wide, padded by half
// a voxel.
float stepDDA2(const float3 ro, const float3 ri, float3 pos, int mip)
{
    const float cell = 8 << mip;
    // per-axis exit face relative to the cell origin
    const float3 face = (ri >= 0.f) ? cell + 0.5f : -0.5f;
    const float3 bias = ri * (face - ro);
    const float3 t_face = floor(pos * (1.f / cell)) * cell * ri + bias;
    return min(t_face.x, min(t_face.y, t_face.z));
}

// Stochastic transmittance along world_ray through the volume AABB, stepping cell by cell with
// stepDDA2 over the local-majorant grid at an adaptive mip level (fixed under MIP_FIXED).
// The shadow-ray counterpart of sample_volume_dda2; both use the same majorant scaling.
//   seed : RNG state
// Returns 1 if the ray misses the AABB, otherwise the estimate (possibly 0 after roulette).
float transmittanceDDA2(const ray world_ray, inout uint seed)
{
    // clip volume
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return 1.0f;
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    float tauToGo = -log(1.f - rng(seed)); // optical depth to the next tentative collision
    float Tr = 1.0f;
#ifdef MIP_FIXED
    float mip = mip_range_fixed;
#else
    float mip = MIP_START;
#endif
    float t = near_far.x + 1e-6f;
    const float3 invRayDir = 1.0 / index_ray.dir;
    while (t < near_far.y)
    {
        const float3 pos = index_ray.origin + t * index_ray.dir;
        const int imip = int(round(mip));
#ifdef USE_TRANSFERFUNC
        // TF alpha of the normalized cell majorant, rescaled to extinction (as in sample_volume_dda2).
        // NOTE(review): a non-monotonic TF can make this smaller than the true local maximum.
        const float majorant = vol_majorant * tf_lookup(lookup_majorant(pos, imip) * vol_inv_majorant).a;
#else
        const float majorant = lookup_majorant(pos, imip);
#endif
        const float nextt = stepDDA2(index_ray.origin, invRayDir, pos, imip);
#ifndef MIP_FIXED
        mip = min(mip + MIP_SPEED_UP, 3.0f); // coarsen while no collision happens
#endif
        const float dt = nextt - t;
        const float dtau = majorant * dt;
        t = nextt;
        tauToGo -= dtau;
        if (tauToGo > 0) continue; // no collision in this cell, step ahead
#ifndef MIP_FIXED
        mip = max(0.0f, mip - MIP_SPEED_DOWN); // refine after a tentative collision
#endif
        t += dt * tauToGo / dtau; // step back to the point of collision
        if (t >= near_far.y) break;

#ifdef USE_TRANSFERFUNC
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float density = vol_majorant * rgba.a;
#else
        const float density = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
#endif
        if (rng(seed) * majorant < density)
        {
            // adjust by ratio of global to local majorant
            // NOTE(review): if local majorants never exceed vol_majorant this factor is always 0
            Tr *= max(0.f, 1.f - vol_majorant / majorant);
            // russian roulette
            if (Tr < 0.1f) {
                const float prob = 1 - Tr;
                if (rng(seed) < prob) return 0.f;
                Tr /= 1.0f - prob;
            }
        }
        tauToGo = -log(1.f - rng(seed));
    }
    return Tr;
}

// DDA-based volume sampling
// Delta tracking with local majorants: steps cell by cell with stepDDA2 over the majorant grid
// at an adaptive mip level (held fixed at mip_range_fixed when MIP_FIXED is defined) and stops
// at the first real collision.
//   t          : (out) ray parameter of the collision (0 if the AABB is missed)
//   throughput : scaled by vol_albedo (and the TF color with USE_TRANSFERFUNC) on a real collision
//   Le         : accumulates throughput * (1 - vol_albedo) * density * vol_inv_majorant at every
//                tentative collision (constant emission term)
//   seed       : RNG state
// Returns true on a real collision, false if the ray leaves (or misses) the volume.
bool sample_volume_dda2(const ray world_ray, out float t, inout float3 throughput, inout float3 Le, inout uint seed) {
    // clip volume
    t = 0.0f;
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return false;
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    const float3 invRayDir = 1.f / index_ray.dir;
    // march brick grid
    t = near_far.x + 1e-6f;
    float tauToGo = -log(1.0f - rng(seed)); // optical depth to the next tentative collision

#ifdef MIP_FIXED
    float mip = mip_range_fixed;
#else
    float mip = MIP_START;
#endif

    while (t < near_far.y) {
        const float3 pos = index_ray.origin + t * index_ray.dir;
        const int imip = int(round(mip));
#ifdef USE_TRANSFERFUNC
        // TF alpha of the normalized cell majorant, rescaled to an extinction value
        const float majorant = vol_majorant * tf_lookup(lookup_majorant(pos, imip) * vol_inv_majorant).a;
#else
        const float majorant = lookup_majorant(pos, imip);
#endif
        const float nextt = stepDDA2(index_ray.origin, invRayDir, pos, imip);
#ifndef MIP_FIXED
        // MIP_FIXED must keep the level constant, as in sample_volume_dda / transmittanceDDA
        mip = min(mip + MIP_SPEED_UP, 3.f); // coarsen while no collision happens
#endif
        const float dt = nextt - t;
        const float dtau = majorant * dt;
        t = nextt;
        tauToGo -= dtau;
        if (tauToGo > 0) continue; // no collision, step ahead
#ifndef MIP_FIXED
        mip = max(0.0, mip - MIP_SPEED_DOWN); // refine after a tentative collision
#endif
        t += dt * tauToGo / dtau; // step back to point of collision
        if (t >= near_far.y) break;
#ifdef USE_TRANSFERFUNC
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float density = vol_majorant * rgba.a;
#else
        const float density = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
#endif
        //Le += throughput * (1.f - vol_albedo) * lookup_emission(ipos + t * idir, seed) * d * vol_inv_majorant;
        Le += throughput * (1.f - vol_albedo) * density * vol_inv_majorant;
        if (rng(seed) * majorant < density) { // check if real or null collision
            throughput *= vol_albedo;
#ifdef USE_TRANSFERFUNC
            throughput *= rgba.rgb;
#endif
            return true;
        }
        tauToGo = -log(1.0 - rng(seed));
    }
    return false;
}

// DDA-based transmittance
// Variant of transmittanceDDA2 that steps with stepDDA (distance to the cell boundary,
// relative to the current position). Within this file it is not called; trace_path uses
// transmittanceDDA2. Mip adaptation is disabled under MIP_FIXED.
// Returns 1 if the ray misses the AABB, otherwise the estimate (possibly 0 after roulette).
float transmittanceDDA(const ray world_ray, inout uint seed) {
    // clip volume
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return 1.0f;
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    const float3 ri = 1.0f / index_ray.dir;
    // march brick grid
    float t = near_far.x + 1e-6f;
    float Tr = 1.0f;
    float tau = -log(1.0f - rng(seed)); // optical depth to the next tentative collision
    
    #ifdef MIP_FIXED
    float mip = mip_range_fixed;
    #else
    float mip = MIP_START;
    #endif
    
    while (t < near_far.y) {
        const float3 curr = index_ray.origin + t * index_ray.dir;
        #ifdef USE_TRANSFERFUNC
        const float majorant = vol_majorant * tf_lookup(lookup_majorant(curr, int(round(mip))) * vol_inv_majorant).a;
        #else
        const float majorant = lookup_majorant(curr, int(round(mip)));
        #endif
        // stepDDA returns the distance from curr to the cell boundary
        const float dt = stepDDA(curr, ri, int(round(mip)));
        t += dt;
        tau -= majorant * dt;
        #ifndef MIP_FIXED
        mip = min(mip + MIP_SPEED_UP, 3.0f);
        #endif
        if (tau > 0.0f) continue; // no collision, step ahead
        t += tau / majorant; // step back to point of collision
        if (t >= near_far.y) break;
        #ifdef USE_TRANSFERFUNC
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float d = vol_majorant * rgba.a;
        #else
        const float d = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
        #endif
        if (rng(seed) * majorant < d) { // check if real or null collision
            // NOTE(review): if local majorants never exceed vol_majorant this factor is always 0,
            // so a real collision ends the estimate with Tr = 0; confirm this is intended.
            Tr *= max(0.f, 1.f - vol_majorant / majorant); // adjust by ratio of global to local majorant
            // russian roulette
            if (Tr < 0.1f) {
                const float prob = 1 - Tr;
                if (rng(seed) < prob) return 0.f;
                Tr /= 1.0f - prob;
            }
        }
        tau = -log(1.0f - rng(seed));
        #ifndef MIP_FIXED
        mip = max(0.f, mip - MIP_SPEED_DOWN);
        #endif
    }
    return Tr;
}

    
// DDA-based volume sampling
// Variant of sample_volume_dda2 that steps with stepDDA (distance to the cell boundary,
// relative to the current position). Within this file it is not called; trace_path uses
// sample_volume_dda2. Mip adaptation is disabled under MIP_FIXED.
//   t          : (out) ray parameter of the collision (0 if the AABB is missed)
//   throughput : scaled by vol_albedo (and the TF color with USE_TRANSFERFUNC) on a real collision
//   Le         : accumulates throughput * (1 - vol_albedo) * d * vol_inv_majorant per tentative collision
// Returns true on a real collision, false if the ray leaves (or misses) the volume.
bool sample_volume_dda(const ray world_ray, out float t, inout float3 throughput, inout float3 Le, inout uint seed) {
    // clip volume
    t = 0.0f;
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far)) return false;
    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    const float3 ri = 1.f / index_ray.dir;
    // march brick grid
    t = near_far.x + 1e-6f;
    float tau = -log(1.0 - rng(seed)); // optical depth to the next tentative collision

    #ifdef MIP_FIXED
    float mip = mip_range_fixed;
    #else
    float mip = MIP_START;
    #endif
    
    while (t < near_far.y) {
        const float3 curr = index_ray.origin + t * index_ray.dir;
#ifdef USE_TRANSFERFUNC
        const float majorant = vol_majorant * tf_lookup(lookup_majorant(curr, int(round(mip))) * vol_inv_majorant).a;
#else
        const float majorant = lookup_majorant(curr, int(round(mip)));
#endif
        // stepDDA returns the distance from curr to the cell boundary
        const float dt = stepDDA(curr, ri, int(round(mip)));
        t += dt;
        tau -= majorant * dt;
        #ifndef MIP_FIXED
        mip = min(mip + MIP_SPEED_UP, 3.f);
        #endif
        if (tau > 0) continue; // no collision, step ahead
        t += tau / majorant; // step back to point of collision
        if (t >= near_far.y) break;
#ifdef USE_TRANSFERFUNC
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float d = vol_majorant * rgba.a;
#else
        const float d = lookup_density_stochastic(index_ray.origin + t * index_ray.dir, seed);
#endif
        //Le += throughput * (1.f - vol_albedo) * lookup_emission(ipos + t * idir, seed) * d * vol_inv_majorant;
        Le += throughput * (1.f - vol_albedo) * d * vol_inv_majorant;
        if (rng(seed) * majorant < d) { // check if real or null collision
            throughput *= vol_albedo;
#ifdef USE_TRANSFERFUNC
            throughput *= rgba.rgb;
#endif
            return true;
        }
        tau = -log(1.0 - rng(seed));
        #ifndef MIP_FIXED 
        mip = max(0.0, mip - MIP_SPEED_DOWN);
        #endif
    }
    return false;
}

// --------------------------------------------------------------
// volumetric path tracing

// Volumetric path tracer. The loop repeats three steps:
//   1. Delta-track the ray to the next real collision.
//   2. Do next-event estimation towards the environment.
//   3. Scatter with the Henyey-Greenstein phase function.
// Environment light reached by NEE and by escaping (phase-sampled) rays is combined with the
// power heuristic.
//   ray   : world-space ray (advanced in place along the path)
//   seed  : RNG state
//   depth : (out) ray parameter of the first volume collision, or -1 if there was none
// Returns float4(radiance, alpha). Alpha is 1 if at least one real collision happened,
// else 0. Rays that never hit the volume return (0, 0, 0, 0) when show_environment == 0.
float4 trace_path(ray ray, inout uint seed, out float depth) {
    depth = -1;
    // trace path
    float3 L = float3(0, 0, 0);
    float3 throughput = float3(1, 1, 1);
    bool free_path = true, hit_volume = false;
    uint n_paths = 0;
    float t, f_p = 0; // t: end of ray segment (i.e. sampled position or out of volume), f_p: last phase function sample for MIS
#ifdef USE_DDA
    while (sample_volume_dda2(ray, t, throughput, L, seed)) {
#else
    while (sample_volume(ray, t, throughput, L, seed)) {
#endif
        if(!hit_volume) {
            depth = t;
            hit_volume = true;
        }
        
        // advance ray
        ray.origin = ray.origin + t * ray.dir;

        // sample light source (environment)
        float3 w_i;
        const float4 Le_pdf = sample_environment(rng2(seed), w_i);
        if (Le_pdf.w > 0) {
            f_p = phase_henyey_greenstein(dot(-ray.dir, w_i), vol_phase_g);
            // MIS weight of the light-sampling strategy. Escaping rays below always receive the
            // complementary weight power_heuristic(f_p, pdf), so this one must always be the
            // power heuristic too; forcing it to 1 over-counted environment lighting.
            // show_environment only controls visibility of the environment for missed rays.
            const float mis_weight = power_heuristic(Le_pdf.w, f_p);
#ifdef USE_DDA
            const float Tr = transmittanceDDA2(create_ray(ray.origin, w_i), seed);
#else
            const float Tr = transmittance(create_ray(ray.origin, w_i), seed);
#endif
            L += throughput * mis_weight * f_p * Tr * Le_pdf.rgb / Le_pdf.w;
        }

        // early out?
        if (++n_paths >= bounces) { free_path = false; break; }
        // russian roulette
        const float rr_val = luma(throughput);
        if (rr_val < .1f) {
            const float prob = 1 - rr_val;
            if (rng(seed) < prob) { free_path = false; break; }
            throughput /= 1 - prob;
        }

        // scatter ray
        const float3 scatter_dir = sample_phase_henyey_greenstein(ray.dir, vol_phase_g, rng2(seed));
        f_p = phase_henyey_greenstein(dot(-ray.dir, scatter_dir), vol_phase_g);
        ray.dir = scatter_dir;
    }

    // did not hit volume at all, return transparent color so we can blend it with rest of scene
    if (!hit_volume && show_environment == 0)
    {
       return float4(0, 0, 0 ,0);
    }

    // free path? -> add envmap contribution
    if (free_path) {
        const float3 Le = lookup_environment(ray.dir);
        // phase-sampling MIS weight; the light pdf is 0 (weight 1) when env sampling is disabled
        const float mis_weight = n_paths > 0 ? power_heuristic(f_p, environment_light(ray.origin, ray.dir)) : 1.f;
        L += throughput * mis_weight * Le;
    }

    return float4(L, clamp(n_paths, 0.f, 1.f));
}

// --------------------------------------------------------------
// simple direct volume rendering
#define RAYMARCH_STEPS 64
#ifdef USE_TRANSFERFUNC
// Emission-absorption ray marching through the volume AABB with the transfer function:
// RAYMARCH_STEPS samples at a jittered start offset between box entry and exit.
//   world_ray : world-space ray
//   seed      : RNG state (used for the start jitter)
// Returns color and alpha. Alpha is 1 when the environment is shown or the ray is fully
// absorbed, otherwise the accumulated opacity 1 - Tr. Rays missing the box return the
// environment (alpha 1) if show_environment > 0, else transparent black.
float4 direct_volume_rendering(ray world_ray, inout uint seed) {
    float3 L = float3(0, 0, 0);
    // clip volume
    float2 near_far;
    if (!intersect_box(world_ray, vol_bb_min, vol_bb_max, near_far))
        return show_environment > 0 ? float4(lookup_environment(world_ray.dir), 1) : float4(0, 0, 0, 0);

    // to index-space
    const ray index_ray = create_ray(
        mul(vol_density_inv_transform, float4(world_ray.origin, 1)).xyz,
        mul(vol_density_inv_transform, float4(world_ray.dir, 0)).xyz // non-normalized!
    );
    // ray marching (t is shared between the world-space and index-space rays)
    const float dt = (near_far.y - near_far.x) / float(RAYMARCH_STEPS);
    near_far.x += rng(seed) * dt; // jitter starting position
    float Tr = 1.f;
    for (int i = 0; i < RAYMARCH_STEPS; ++i) {
        const float t = min(near_far.x + float(i) * dt, near_far.y);
        // normalize by the global majorant before the TF lookup, like every other tf_lookup call
        const float4 rgba = tf_lookup(lookup_density_trilinear(index_ray.origin + t * index_ray.dir) * vol_inv_majorant);
        const float dtau = rgba.a * dt;
        L += rgba.rgb * dtau * Tr;
        Tr *= exp(-dtau);
        if (Tr <= 1e-6) return float4(L, 1);
    }
    if (show_environment > 0)
    {
        // the environment is indexed by the world-space direction; index_ray.dir is in index
        // space and not normalized
        return float4(L + lookup_environment(world_ray.dir) * Tr, 1);
    }
    return float4(L, 1 - Tr);
}
#endif

// Debug view: opaque cyan where the ray hits the volume AABB, transparent black elsewhere.
float4 render_intersect(const ray ray)
{
    float2 near_far;
    const bool hit = intersect_box(ray, vol_bb_min, vol_bb_max, near_far);
    return hit ? float4(0, 1, 1, 1) : float4(0, 0, 0, 0);
}