//#define GENERATE_PARALLEL // use parallel iterators to compute data

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Unity.Mathematics;
using UnityEngine;
using UnityEngine.Assertions;
#if UNITY_EDITOR
// The UnityEditor assembly does not exist in player builds; every use in this file is editor-only.
using UnityEditor;
#endif

/// <summary>Texture layout used when building a volume's density textures.</summary>
public enum VolumeFormat
{
    /// <summary>Bricked layout: indirection texture, per-brick range texture and a packed atlas.</summary>
    Indirect = 0,

    /// <summary>Presumably an unbricked layout; VolumeData.BuildTextures currently rejects it as unsupported.</summary>
    Raw = 1,
}

/// <summary>Precision requested for the per-brick (min, max) range texture.</summary>
public enum RangeFormat
{
    Float = 0,
    Half = 1,
}
/// <summary>Precision requested for the packed brick atlas texture.</summary>
public enum AtlasFormat
{
    Float = 0,
    Half = 1,
}
/// <summary>
/// Format tag for the indirection texture. Float is the only value; note that
/// VolumeData.BuildTextures currently creates that texture as 8-bit RGB24 regardless.
/// </summary>
public enum IndirectionFormat
{
    Float = 0,
}

/// <summary>
/// CPU-side output of VolumeData.ComputeBrickData, ready to be uploaded as textures.
/// </summary>
public struct BrickDataResult
{
    /// <summary>
    /// One entry per brick: atlas slot coordinates in r, g, b (alpha 0).
    /// Constant bricks keep (0, 0, 0, 0).
    /// </summary>
    public Buf3D<Color32> Indirection;

    /// <summary>One entry per brick: normalised minimum in r, maximum in g.</summary>
    public Buf3D<Color> Range;

    /// <summary>Voxels of every non-constant brick (red channel), rescaled to that brick's own range.</summary>
    public Buf3D<Color> Atlas;

    /// <summary>Successively halved min/max reductions of <see cref="Range"/>, one per mip level.</summary>
    public Buf3D<Color>[] RangeMipmaps;

    /// <summary>Minimum (x) and maximum (y) of the raw input values.</summary>
    public Vector2 MinMax;
}

/// <summary>
/// Scalar volume plus the GPU textures derived from it. The volume is split into bricks of
/// BrickSize^3 voxels: every brick gets a normalised (min, max) entry in a range texture, and
/// each non-constant brick is copied into a packed atlas that an indirection texture points into.
/// </summary>
public class VolumeData : ScriptableObject
{
    private const string SourceFolder = "Assets/Resources/Volumes/";
    private const string TextureSourceFolder = SourceFolder + "Textures/";

    // Value window binned by GenerateHistogramTexture; samples outside it are clamped into the end bins.
    public const int MIN_HOUNSFIELD = -1024;
    public const int MAX_HOUNSFIELD = 3071;

    // Source voxel values; null when none were supplied to Init or BuildTextures dropped them (keepRaw = false).
    private Buf3D<float> rawData;

    // Settings of the most recent BuildTextures call.
    // NOTE(review): Unity does not serialize auto-properties, so these probably revert to their defaults
    // when the asset is reloaded — confirm whether any consumer relies on them afterwards.
    public VolumeFormat volumeFormat { get; private set; }
    public RangeFormat rangeFormat { get; private set; }
    public AtlasFormat atlasFormat { get; private set; }
    public IndirectionFormat indirectionFormat { get; private set; }
    public uint mipCount { get; private set; }

    public Vector2 densityMinMaj;
    public uint3 dimensions;
    [SerializeField] private Texture3D densityIndirection;
    [SerializeField] private Texture3D densityRange;
    [SerializeField] private Texture3D densityAtlas;
    [SerializeField] private Texture3D[] densityRangeMips;
    [SerializeField] private Texture2D histogramTexture;
    public Vector3 volumeScale;
    public Quaternion volumeRotation;
    [SerializeField] private Vector2 volumeMinMax;
    public string dataName;

    // Voxels along one brick edge.
    public const uint BrickSize = 8;
    // MaxBricks = 2^BitsPerAxis bricks per axis; ComputeBrickData checks it because the
    // indirection texture stores atlas coordinates as bytes.
    public const uint BitsPerAxis = 8;
    public const uint MaxBricks = 1 << (int)BitsPerAxis;
    public const uint VoxelsPerBrick = BrickSize * BrickSize * BrickSize;

    /// <summary>Stores the source volume and its metadata; no textures are built yet.</summary>
    /// <param name="rawData">Voxel values; kept until BuildTextures is called with keepRaw = false.</param>
    /// <param name="dataName">Base name of the assets written when textures are stored as assets.</param>
    /// <param name="scale">Stored as <see cref="volumeScale"/>.</param>
    public void Init(Buf3D<float> rawData, string dataName, Vector3 scale)
    {
        this.rawData = rawData;
        this.dataName = dataName;
        this.volumeScale = scale;
        this.volumeRotation = Quaternion.identity;
    }

    /// <summary>True while the raw voxel values are still held in memory.</summary>
    public bool HasRaw()
    {
        return rawData != null;
    }

    /// <summary>
    /// Builds the histogram and, for <see cref="VolumeFormat.Indirect"/>, the indirection, range and
    /// atlas textures (plus optional standalone range-mip textures) from the raw data.
    /// </summary>
    /// <param name="volumeFormat">Texture layout; <see cref="VolumeFormat.Raw"/> is not supported and only logs an error.</param>
    /// <param name="rangeFormat">Range texture precision: Float → RGFloat, Half → RGHalf.</param>
    /// <param name="atlasFormat">Atlas texture precision: Float → RFloat, Half → RHalf.</param>
    /// <param name="indirectionFormat">Recorded only; the indirection texture is always RGB24.</param>
    /// <param name="mipCount">Number of range mip levels below the base level.</param>
    /// <param name="createMipTextures">Also create one standalone texture per range mip level.</param>
    /// <param name="storeAsAsset">Write the textures and this object into the asset database (editor only).</param>
    /// <param name="keepRaw">Keep the raw voxel values after building.</param>
    public void BuildTextures(
        VolumeFormat volumeFormat = VolumeFormat.Indirect,
        RangeFormat rangeFormat = RangeFormat.Float,
        AtlasFormat atlasFormat = AtlasFormat.Half,
        IndirectionFormat indirectionFormat = IndirectionFormat.Float,
        uint mipCount = 3,
        bool createMipTextures = false,
        bool storeAsAsset = false,
        bool keepRaw = true)
    {
        Assert.IsTrue(HasRaw());

        this.volumeFormat = volumeFormat;
        this.rangeFormat = rangeFormat;
        this.atlasFormat = atlasFormat;
        this.indirectionFormat = indirectionFormat;
        this.mipCount = mipCount;
        densityMinMaj = new Vector2(0, 1);
        dimensions = rawData.Dimensions;

        var data = ComputeBrickData();

        GenerateHistogramTexture();

        if (volumeFormat == VolumeFormat.Raw)
        {
            Debug.LogError("Unsupported format");
        }
        else
        {
            var rangeTextureFormat = ToTextureFormat(rangeFormat);

            densityIndirection = CreateTexture(data.Indirection, TextureFormat.RGB24);
            densityAtlas = CreateTexture(data.Atlas, ToTextureFormat(atlasFormat));
            densityRange = CreateTexture(data.Range, rangeTextureFormat, data.RangeMipmaps);

            if (createMipTextures && mipCount > 0)
            {
                densityRangeMips = new Texture3D[mipCount];
                for (var i = 0; i < mipCount; ++i)
                {
                    densityRangeMips[i] = CreateTexture(data.RangeMipmaps[i], rangeTextureFormat);
                }
            }
            else
            {
                densityRangeMips = Array.Empty<Texture3D>();
            }
        }

        if (!keepRaw)
        {
            rawData = null;
        }

        if (storeAsAsset)
        {
            SaveAssets();
        }
    }

    /// <summary>
    /// Splits the raw volume into bricks and computes the indirection, range, atlas and range-mip
    /// buffers. Also updates the stored global value range (see <see cref="GetVolumeMinMax"/>).
    /// </summary>
    public BrickDataResult ComputeBrickData()
    {
        Assert.IsTrue(HasRaw());

        // Global value range; LookupRaw normalises voxels against it. Two full passes over the data.
        var minValue = rawData.RawIter.Min();
        var maxValue = rawData.RawIter.Max();
        volumeMinMax = new Vector2(minValue, maxValue);

        var brickSize = new uint3(BrickSize);
        // Pad the brick grid to a multiple of 2^mipCount so every range mip level halves exactly.
        var powerOfMipMaps = new uint3(1 << (int)mipCount);
        var nBricks = DivRoundUp(DivRoundUp(rawData.Dimensions, brickSize), powerOfMipMaps) * powerOfMipMaps;

        // Atlas coordinates are written into byte channels below; larger grids would silently wrap.
        if (math.any(nBricks > MaxBricks))
        {
            Debug.LogError("Brick grid " + nBricks + " exceeds the maximum of " + MaxBricks + " bricks per axis");
        }

        var indirection = new Buf3D<Color32>(nBricks);
        var range = new Buf3D<Color>(nBricks);
        // One spare z-slice of bricks: the Prune below keeps room for (brickCounter + 1) bricks,
        // which exceeds nBricks when no brick is constant.
        var atlas = new Buf3D<Color>((nBricks + new uint3(0, 0, 1)) * brickSize);

        var brickCounter = 0;
#if GENERATE_PARALLEL
        Parallel.ForEach(Util.EnumerateIndices(new uint3(0, 0, 0), nBricks), brickIndex =>
#else
        foreach (var brickIndex in Util.EnumerateIndices(new uint3(0, 0, 0), nBricks))
#endif
            {
                indirection[brickIndex] = new Color32(0, 0, 0, 0);
                var brickOrigin = (int3)(brickIndex * brickSize);

                // Value range over local indices -1 .. brickSize + 1 around the brick origin;
                // voxels outside the volume count as 0.
                float localMin = float.MaxValue, localMax = float.MinValue;
                foreach (var localIndex in Util.EnumerateIndices(new int3(-1), (int3)brickSize + new int3(1)))
                {
                    var value = LookupRaw(brickOrigin + localIndex).GetValueOrDefault(0);
                    localMax = Math.Max(localMax, value);
                    localMin = Math.Min(localMin, value);
                }

                range[brickIndex] = new Color(localMin, localMax, 0);

                // Constant bricks get no atlas slot; their range entry alone describes them.
                if (localMax.Equals(localMin))
#if GENERATE_PARALLEL
                    return;
#else
                    continue;
#endif

#if GENERATE_PARALLEL
                // Interlocked.Increment returns the incremented value; subtract one so slots start at 0,
                // exactly like the sequential path.
                var brickNum = Interlocked.Increment(ref brickCounter) - 1;
#else
                var brickNum = brickCounter++;
#endif

                var indirectionPtr = indirection.ToCoord(brickNum);
                indirection[brickIndex] =
                    new Color32((byte)indirectionPtr.x, (byte)indirectionPtr.y, (byte)indirectionPtr.z, 0);

                var atlasOrigin = indirectionPtr * brickSize;
                if (math.any(atlasOrigin + brickSize > atlas.Dimensions))
                {
                    Debug.LogError("Atlas index " + atlasOrigin + " out of range for dimensions " +
                                   atlas.Dimensions);
                }

                foreach (var localIndex in Util.EnumerateIndices(new uint3(0), brickSize))
                {
                    var value = LookupRaw((int3)(brickIndex * brickSize + localIndex)).GetValueOrDefault(0);
                    // Store the voxel relative to this brick's [localMin, localMax] (non-zero width here).
                    value = math.clamp((value - localMin) / (localMax - localMin), 0, 1);
                    atlas[atlasOrigin + localIndex] = new Color(value, 0, 0);
                }
            }
#if GENERATE_PARALLEL
        );
#endif

        // Keep only the z-slices holding used slots: ceil((brickCounter + 1) / bricksPerSlice) in integers.
        var bricksPerSlice = nBricks.x * nBricks.y;
        var usedSlices = ((uint)brickCounter + bricksPerSlice) / bricksPerSlice;
        atlas.Prune(BrickSize * usedSlices);

        var rangeMipmaps = new Buf3D<Color>[mipCount];
        for (uint mip = 0; mip < mipCount; ++mip)
        {
            // Each level is half the previous one; a texel merges its children at brickIndex * 2 + {0, 1}^3.
            var mipSize = nBricks / (uint)(1 << (int)(mip + 1));
            var target = new Buf3D<Color>(mipSize);
            var source = mip == 0 ? range : rangeMipmaps[mip - 1];
            foreach (var brickIndex in Util.EnumerateIndices(new uint3(0), mipSize))
            {
                float rangeMin = float.MaxValue, rangeMax = float.MinValue;
                foreach (var childOffset in Util.EnumerateIndices(new uint3(0), new uint3(2)))
                {
                    var child = source[brickIndex * 2 + childOffset];
                    rangeMin = Math.Min(rangeMin, child.r);
                    rangeMax = Math.Max(rangeMax, child.g);
                }

                target[brickIndex] = new Color(rangeMin, rangeMax, 0);
            }

            rangeMipmaps[mip] = target;
        }

        return new BrickDataResult
        {
            Indirection = indirection,
            Atlas = atlas,
            Range = range,
            RangeMipmaps = rangeMipmaps,
            MinMax = new Vector2(minValue, maxValue)
        };
    }

    /// <summary>
    /// Raw value at <paramref name="index"/> normalised to [0, 1] by the stored global min/max,
    /// or null when the index lies outside the volume.
    /// </summary>
    private float? LookupRaw(int3 index)
    {
        if (math.any(index >= (int3)rawData.Dimensions) || math.any(index < new int3(0))) return null;
        var value = rawData[(uint3)index];

        var extent = volumeMinMax.y - volumeMinMax.x;
        // A constant volume has zero extent; map it to 0 instead of producing NaN.
        if (extent <= 0) return 0f;

        return math.clamp((value - volumeMinMax.x) / extent, 0, 1);
    }

    // Closely follows: https://github.com/mlavik1/UnityVolumeRendering/blob/master/Assets/Scripts/Utils/HistogramTextureGenerator.cs
    /// <summary>
    /// Builds a 1-pixel-high histogram texture of the raw values over [MIN_HOUNSFIELD, MAX_HOUNSFIELD].
    /// Each bin's red channel holds log10(count) / log10(maxCount).
    /// </summary>
    public void GenerateHistogramTexture()
    {
        Assert.IsTrue(HasRaw());

        const float minValue = MIN_HOUNSFIELD;
        const float maxValue = MAX_HOUNSFIELD;
        float valueRange = maxValue - minValue;

        int numFrequencies = Mathf.Min((int)valueRange, 1024);
        int[] frequencies = new int[numFrequencies];

        int maxFreq = 0;
        float valRangeRecip = 1.0f / (maxValue - minValue);
        foreach (var rawDensity in rawData.RawIter)
        {
            float dataValue = math.clamp(rawDensity, minValue, maxValue);
            float tValue = (dataValue - minValue) * valRangeRecip;
            int freqIndex = (int)(tValue * (numFrequencies - 1));
            frequencies[freqIndex] += 1;
            maxFreq = Math.Max(frequencies[freqIndex], maxFreq);
        }

        // log10(0) is -Infinity for empty bins, and log10(1) = 0 makes the divisor vanish when the
        // fullest bin holds a single sample; map both cases to finite heights.
        float logMax = Mathf.Log10(maxFreq);
        Color[] cols = new Color[numFrequencies];
        for (int iSample = 0; iSample < numFrequencies; iSample++)
        {
            int freq = frequencies[iSample];
            float height;
            if (freq <= 0)
            {
                height = 0f;
            }
            else if (logMax <= 0f)
            {
                height = 1f;
            }
            else
            {
                height = Mathf.Log10(freq) / logMax;
            }

            cols[iSample] = new Color(height, 0.0f, 0.0f, 1.0f);
        }

        histogramTexture = new Texture2D(numFrequencies, 1, TextureFormat.RGBAFloat, false);
        histogramTexture.SetPixels(cols);
        histogramTexture.Apply();
    }

    // Maps the requested range precision to a two-channel (min, max) texture format.
    private static TextureFormat ToTextureFormat(RangeFormat format)
    {
        return format == RangeFormat.Half ? TextureFormat.RGHalf : TextureFormat.RGFloat;
    }

    // Maps the requested atlas precision to a single-channel texture format.
    private static TextureFormat ToTextureFormat(AtlasFormat format)
    {
        return format == AtlasFormat.Half ? TextureFormat.RHalf : TextureFormat.RFloat;
    }

    /// <summary>
    /// Creates a clamped, point-filtered 3D texture from <paramref name="data"/>, uploading
    /// <paramref name="mipData"/> (if any) as levels 1..n.
    /// </summary>
    private Texture3D CreateTexture<T>(Buf3D<T> data, TextureFormat format, Buf3D<T>[] mipData = null)
    {
        // Base level plus one per supplied mip; 1 means the texture has only its base level.
        var levelCount = 1 + (mipData?.Length ?? 0);
        var texture = new Texture3D(
            (int)data.Dimensions.x,
            (int)data.Dimensions.y,
            (int)data.Dimensions.z,
            format,
            levelCount)
        {
            wrapMode = TextureWrapMode.Clamp,
            filterMode = FilterMode.Point,
            anisoLevel = 0
        };

        SetTextureData(texture, data, 0);

        if (mipData != null)
        {
            for (var level = 1; level <= mipData.Length; ++level)
            {
                SetTextureData(texture, mipData[level - 1], level);
            }
        }

        // updateMipmaps: false keeps Unity from overwriting the explicitly uploaded levels;
        // makeNoLongerReadable: true releases the CPU-side copy.
        texture.Apply(false, true);

        return texture;
    }

    // Uploads one mip level, using the pixel setter that matches the element type.
    private void SetTextureData<T>(Texture3D texture, Buf3D<T> data, int mipLevel)
    {
        if (typeof(T) == typeof(Color32))
        {
            texture.SetPixels32(((Buf3D<Color32>)(object)data).Data, mipLevel);
        }
        else if (typeof(T) == typeof(Color))
        {
            texture.SetPixels(((Buf3D<Color>)(object)data).Data, mipLevel);
        }
        else
        {
            texture.SetPixelData(data.Data, mipLevel);
        }
    }

    // Writes every built texture and this object into the asset database; textures that were never
    // built (e.g. after the unsupported Raw format) are skipped. Editor-only.
    private void SaveAssets()
    {
#if UNITY_EDITOR
        CreateTextureAsset(densityIndirection, "_vol_indirection");
        CreateTextureAsset(densityRange, "_vol_range");
        CreateTextureAsset(densityAtlas, "_vol_atlas");
        CreateTextureAsset(histogramTexture, "_vol_histogram");

        if (densityRangeMips != null)
        {
            for (var i = 0; i < densityRangeMips.Length; ++i)
            {
                CreateTextureAsset(densityRangeMips[i], "_vol_range" + (i + 1));
            }
        }

        AssetDatabase.CreateAsset(this, SourceFolder + dataName + ".asset");
#else
        Debug.LogError($"Cannot run {nameof(SaveAssets)} within build!");
#endif
    }

#if UNITY_EDITOR
    // Stores one texture as "<TextureSourceFolder><dataName><suffix>.asset" when it exists.
    private void CreateTextureAsset(Texture texture, string suffix)
    {
        if (texture != null)
        {
            AssetDatabase.CreateAsset(texture, TextureSourceFolder + dataName + suffix + ".asset");
        }
    }
#endif

    /// <summary>Binds the brick textures and density bounds to one kernel of <paramref name="shader"/>.</summary>
    public void Bind(ComputeShader shader, int kernelIndex, ShaderIndices shaderIndices)
    {
        shader.SetTexture(kernelIndex, shaderIndices.VolumeDensityIndirection, densityIndirection);
        shader.SetTexture(kernelIndex, shaderIndices.VolumeDensityRange, densityRange);
        shader.SetTexture(kernelIndex, shaderIndices.VolumeDensityAtlas, densityAtlas);

        shader.SetFloat(shaderIndices.VolumeMinorant, densityMinMaj.x);
        shader.SetFloat(shaderIndices.VolumeMajorant, densityMinMaj.y);

        shader.SetFloat(shaderIndices.VolumeInverseMajorant, 1 / densityMinMaj.y);
    }

    // Component-wise ceil(dividend / divisor) for unsigned vectors.
    private static uint3 DivRoundUp(uint3 dividend, uint3 divisor)
    {
        return (dividend + divisor - 1) / divisor;
    }

    /// <summary>Histogram texture built by the last <see cref="GenerateHistogramTexture"/> call.</summary>
    public Texture2D GetHistogramTexture()
    {
        return histogramTexture;
    }

    /// <summary>Global (min, max) of the raw values recorded by the last <see cref="ComputeBrickData"/> call.</summary>
    public Vector2 GetVolumeMinMax()
    {
        return volumeMinMax;
    }
}

/// <summary>
/// Dense 3D buffer stored as a flat array with x varying fastest
/// (index = z * width * height + y * width + x).
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
public class Buf3D<T>
{
    private T[] data;
    private uint3 dimensions;

    /// <summary>Allocates a buffer of the given size filled with default(T).</summary>
    public Buf3D(uint3 dimensions)
    {
        this.dimensions = dimensions;
        data = new T[GetSize()];
    }

    /// <summary>
    /// The backing array. Assigning replaces it; the new array must have exactly as many
    /// elements as the current one.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned array is null.</exception>
    /// <exception cref="ArgumentException">The assigned array has a different length.</exception>
    public T[] Data
    {
        get { return data; }
        set
        {
            if (value == null)
            {
                throw new ArgumentNullException(nameof(value));
            }
            if (data.Length != value.Length)
            {
                throw new ArgumentException("Data length must match. Expected " + data.Length + " but was " + value.Length);
            }
            data = value;
        }
    }

    /// <summary>
    /// Resizes the buffer along z to <paramref name="slices"/> slices, keeping the leading ones.
    /// Intended for shrinking; a larger value grows the buffer with default(T) elements.
    /// </summary>
    public void Prune(uint slices)
    {
        dimensions = new uint3(dimensions.x, dimensions.y, slices);
        Array.Resize(ref data, (int)GetSize());
    }

    /// <summary>Element at <paramref name="coord"/>; coordinates are not range-checked per axis.</summary>
    public T this[uint3 coord]
    {
        get { return data[ToIndex(coord)]; }
        set { data[ToIndex(coord)] = value; }
    }

    /// <summary>All elements in storage order.</summary>
    public IEnumerable<T> RawIter => data;

    /// <summary>Buffer extent along x, y and z.</summary>
    public uint3 Dimensions => dimensions;

    /// <summary>Flat storage index of <paramref name="coord"/>.</summary>
    public uint ToIndex(uint3 coord)
    {
        return coord.z * dimensions.x * dimensions.y + coord.y * dimensions.x + coord.x;
    }

    /// <summary>Inverse of <see cref="ToIndex"/>: the coordinate stored at flat <paramref name="index"/>.</summary>
    public uint3 ToCoord(int index)
    {
        return new uint3(
            (uint)index % dimensions.x,
            (uint)(index / dimensions.x) % dimensions.y,
            (uint)index / (dimensions.x * dimensions.y)
        );
    }

    /// <summary>Total number of elements (width * height * depth).</summary>
    public uint GetSize()
    {
        return dimensions.x * dimensions.y * dimensions.z;
    }
}