using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.Translation.Optimizations;
using System.Collections.Generic;

using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;

namespace Ryujinx.Graphics.Shader.Translation.Transforms
{
    /// <summary>
    /// Transform pass that rewrites a vertex shader so that it can be executed as a compute shader.
    /// </summary>
    /// <remarks>
    /// Built-in vertex inputs are replaced by values read from the vertex info constant buffer, from local
    /// memory, or from the global invocation ID. User-defined attribute loads become buffer texture fetches,
    /// and output loads/stores are redirected to a local memory array.
    /// </remarks>
    class VertexToCompute : ITransformPass
    {
        /// <summary>
        /// Returns whether this pass should run for the current shader.
        /// </summary>
        /// <param name="gpuAccessor">GPU accessor (unused).</param>
        /// <param name="stage">Shader stage (unused here; <see cref="RunPass"/> filters on the stage itself).</param>
        /// <param name="targetLanguage">Target language (unused).</param>
        /// <param name="usedFeatures">Features used by the shader.</param>
        /// <returns>True when <see cref="FeatureFlags.VtgAsCompute"/> is set.</returns>
        public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
        {
            return usedFeatures.HasFlag(FeatureFlags.VtgAsCompute);
        }

        /// <summary>
        /// Rewrites a single IR node of a vertex shader so it no longer depends on vertex-stage inputs or outputs.
        /// </summary>
        /// <remarks>
        /// Input loads of built-in variables and user-defined attributes are replaced, as are loads and stores of
        /// outputs. Replacement operations are inserted around <paramref name="node"/>, which is then deleted.
        /// Nodes of non-vertex shaders and nodes that need no rewrite are returned untouched.
        /// </remarks>
        /// <param name="context">Transform context giving access to definitions, resources and logging.</param>
        /// <param name="node">Node to rewrite; its value is expected to be an <see cref="Operation"/>.</param>
        /// <returns>
        /// The node the caller should continue from: one of the inserted nodes when a rewrite happened,
        /// otherwise <paramref name="node"/> itself.
        /// </returns>
        public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
        {
            if (context.Definitions.Stage != ShaderStage.Vertex)
            {
                return node;
            }

            Operation operation = (Operation)node.Value;

            LinkedListNode<INode> newNode = node;

            if (operation.Inst == Instruction.Load && operation.StorageKind == StorageKind.Input)
            {
                Operand dest = operation.Dest;

                switch ((IoVariable)operation.GetSource(0).Value)
                {
                    case IoVariable.BaseInstance:
                        newNode = GenerateBaseInstanceLoad(context.ResourceManager, node, dest);
                        break;
                    case IoVariable.BaseVertex:
                        newNode = GenerateBaseVertexLoad(context.ResourceManager, node, dest);
                        break;
                    case IoVariable.InstanceId:
                        newNode = GenerateInstanceIdLoad(node, dest);
                        break;
                    case IoVariable.InstanceIndex:
                        newNode = GenerateInstanceIndexLoad(context.ResourceManager, node, dest);
                        break;
                    case IoVariable.VertexId:
                    case IoVariable.VertexIndex:
                        newNode = GenerateVertexIndexLoad(context.ResourceManager, node, dest);
                        break;
                    case IoVariable.UserDefined:
                        int location = operation.GetSource(1).Value;
                        int component = operation.GetSource(2).Value;

                        if (context.Definitions.IsAttributePacked(location))
                        {
                            // Packed attributes are fetched from the start of the element (offset 0);
                            // the requested component is selected through the component mask instead.
                            bool needsSextNorm = context.Definitions.IsAttributePackedRgb10A2Signed(location);

                            Operand temp = needsSextNorm ? Local() : dest;
                            Operand vertexElemOffset = GenerateVertexOffset(context.ResourceManager, node, location, 0);

                            newNode = node.List.AddBefore(node, new TextureOperation(
                                Instruction.TextureSample,
                                SamplerType.TextureBuffer,
                                TextureFormat.Unknown,
                                TextureFlags.IntCoords,
                                context.ResourceManager.Reservations.GetVertexBufferTextureBinding(location),
                                1 << component,
                                new[] { temp },
                                new[] { vertexElemOffset }));

                            if (needsSextNorm)
                            {
                                // Signed RGB10A2: sign-extend the 2-bit (alpha) or 10-bit (RGB) field,
                                // and normalize it to float unless the attribute is an integer (SINT) one.
                                // Anchored on the fetch so the sequence directly follows it.
                                bool sint = context.Definitions.IsAttributeSint(location);
                                CopySignExtendedNormalized(newNode, component == 3 ? 2 : 10, !sint, dest, temp);
                            }
                        }
                        else
                        {
                            // Non-packed attributes fetch a single element at (attribute offset + component).
                            // Components beyond X go through a temporary so they can be masked afterwards.
                            Operand temp = component > 0 ? Local() : dest;
                            Operand vertexElemOffset = GenerateVertexOffset(context.ResourceManager, node, location, component);

                            newNode = node.List.AddBefore(node, new TextureOperation(
                                Instruction.TextureSample,
                                SamplerType.TextureBuffer,
                                TextureFormat.Unknown,
                                TextureFlags.IntCoords,
                                context.ResourceManager.Reservations.GetVertexBufferTextureBinding(location),
                                1,
                                new[] { temp },
                                new[] { vertexElemOffset }));

                            if (component > 0)
                            {
                                newNode = CopyMasked(context.ResourceManager, newNode, location, component, dest, temp);
                            }
                        }
                        break;
                    case IoVariable.GlobalId:
                    case IoVariable.SubgroupEqMask:
                    case IoVariable.SubgroupGeMask:
                    case IoVariable.SubgroupGtMask:
                    case IoVariable.SubgroupLaneId:
                    case IoVariable.SubgroupLeMask:
                    case IoVariable.SubgroupLtMask:
                        // These inputs are expected here and are left untouched
                        // (presumably because they are also directly available to compute shaders).
                        break;
                    default:
                        context.GpuAccessor.Log($"Invalid input \"{(IoVariable)operation.GetSource(0).Value}\".");
                        break;
                }
            }
            else if (operation.Inst == Instruction.Load && operation.StorageKind == StorageKind.Output)
            {
                // Outputs live in a local memory array; reading one becomes a local memory load.
                if (TryGetOutputOffset(context.ResourceManager, operation, out int outputOffset))
                {
                    newNode = node.List.AddBefore(node, new Operation(
                        Instruction.Load,
                        StorageKind.LocalMemory,
                        operation.Dest,
                        new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset) }));
                }
                else
                {
                    context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
                }
            }
            else if (operation.Inst == Instruction.Store && operation.StorageKind == StorageKind.Output)
            {
                // Writing an output becomes a store into the same local memory array.
                if (TryGetOutputOffset(context.ResourceManager, operation, out int outputOffset))
                {
                    // The stored value is always the last source of an output store.
                    Operand value = operation.GetSource(operation.SourcesCount - 1);

                    newNode = node.List.AddBefore(node, new Operation(
                        Instruction.Store,
                        StorageKind.LocalMemory,
                        (Operand)null,
                        new[] { Const(context.ResourceManager.LocalVertexDataMemoryId), Const(outputOffset), value }));
                }
                else
                {
                    context.GpuAccessor.Log($"Invalid output \"{(IoVariable)operation.GetSource(0).Value}\".");
                }
            }

            // Any replacement makes the original operation redundant.
            if (newNode != node)
            {
                Utils.DeleteNode(node, operation);
            }

            return newNode;
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, the computation of the buffer element offset of an attribute
        /// component for the current vertex.
        /// </summary>
        /// <remarks>
        /// The result is <c>VertexOffsets[location].x + vertexId * VertexStrides[location].x + component</c>.
        /// <c>vertexId</c> is the instance-rate index when <c>VertexOffsets[location].y</c> is set, and the
        /// vertex-rate index otherwise. Both indices are read from local memory.
        /// </remarks>
        /// <param name="resourceManager">Resource manager holding the constant buffer and local memory reservations.</param>
        /// <param name="node">Node before which the operations are inserted.</param>
        /// <param name="location">Attribute location.</param>
        /// <param name="component">Element offset added to the attribute base; 0 adds nothing.</param>
        /// <returns>Operand holding the element offset (presumably in buffer texture elements — confirm with the buffer setup).</returns>
        private static Operand GenerateVertexOffset(ResourceManager resourceManager, LinkedListNode<INode> node, int location, int component)
        {
            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;

            Operand vertexIdVr = Local();
            GenerateVertexIdVertexRateLoad(resourceManager, node, vertexIdVr);

            Operand vertexIdIr = Local();
            GenerateVertexIdInstanceRateLoad(resourceManager, node, vertexIdIr);

            Operand attributeOffset = Local();
            node.List.AddBefore(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                attributeOffset,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexOffsets), Const(location), Const(0) }));

            Operand isInstanceRate = Local();
            node.List.AddBefore(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                isInstanceRate,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexOffsets), Const(location), Const(1) }));

            // Pick the instance-rate or vertex-rate index depending on the attribute's step mode.
            Operand vertexId = Local();
            node.List.AddBefore(node, new Operation(
                Instruction.ConditionalSelect,
                vertexId,
                new[] { isInstanceRate, vertexIdIr, vertexIdVr }));

            Operand vertexStride = Local();
            node.List.AddBefore(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                vertexStride,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexStrides), Const(location), Const(0) }));

            Operand vertexBaseOffset = Local();
            node.List.AddBefore(node, new Operation(Instruction.Multiply, vertexBaseOffset, new[] { vertexId, vertexStride }));

            Operand vertexOffset = Local();
            node.List.AddBefore(node, new Operation(Instruction.Add, vertexOffset, new[] { attributeOffset, vertexBaseOffset }));

            Operand vertexElemOffset;

            if (component != 0)
            {
                vertexElemOffset = Local();

                node.List.AddBefore(node, new Operation(Instruction.Add, vertexElemOffset, new[] { vertexOffset, Const(component) }));
            }
            else
            {
                vertexElemOffset = vertexOffset;
            }

            return vertexElemOffset;
        }

        /// <summary>
        /// Emits, after <paramref name="node"/>, operations that sign-extend the low <paramref name="bits"/> bits of
        /// <paramref name="src"/>. Optionally converts the result to a signed-normalized float, then writes it to
        /// <paramref name="dest"/>.
        /// </summary>
        /// <remarks>
        /// Normalization follows the standard signed-normalized fixed-point conversion
        /// <c>f = max(c / (2^(bits-1) - 1), -1.0)</c>, so the largest positive code maps exactly to 1.0.
        /// The most negative code is clamped to -1.0.
        /// </remarks>
        /// <param name="node">Node after which the operations are inserted.</param>
        /// <param name="bits">Width of the signed field; must be at least 2 when <paramref name="normalize"/> is set.</param>
        /// <param name="normalize">True to convert to a normalized float; false to keep the sign-extended integer.</param>
        /// <param name="dest">Destination operand.</param>
        /// <param name="src">Source operand holding the raw field in its low bits.</param>
        /// <returns>The last inserted node.</returns>
        private static LinkedListNode<INode> CopySignExtendedNormalized(LinkedListNode<INode> node, int bits, bool normalize, Operand dest, Operand src)
        {
            // Shift the field to the top of the word, then arithmetic-shift it back down to sign-extend it.
            Operand leftShifted = Local();
            node = node.List.AddAfter(node, new Operation(
                Instruction.ShiftLeft,
                leftShifted,
                new[] { src, Const(32 - bits) }));

            Operand rightShifted = normalize ? Local() : dest;
            node = node.List.AddAfter(node, new Operation(
                Instruction.ShiftRightS32,
                rightShifted,
                new[] { leftShifted, Const(32 - bits) }));

            if (normalize)
            {
                Operand asFloat = Local();
                node = node.List.AddAfter(node, new Operation(Instruction.ConvertS32ToFP32, asFloat, new[] { rightShifted }));

                // Divide by the largest positive code (2^(bits-1) - 1), not by 2^(bits-1); the latter
                // would never reach +1.0 (e.g. a 2-bit alpha of 1 would become 0.5 instead of 1.0).
                Operand scaled = Local();
                node = node.List.AddAfter(node, new Operation(
                    Instruction.FP32 | Instruction.Multiply,
                    scaled,
                    new[] { asFloat, ConstF(1f / ((1 << (bits - 1)) - 1)) }));

                // The most negative code scales slightly below -1.0 (to -2.0 for 2 bits); clamp it.
                node = node.List.AddAfter(node, new Operation(
                    Instruction.FP32 | Instruction.Maximum,
                    dest,
                    new[] { scaled, ConstF(-1f) }));
            }

            return node;
        }

        /// <summary>
        /// Emits, after <paramref name="node"/>, a copy of <paramref name="src"/> into <paramref name="dest"/>.
        /// The copy only happens when the attribute component is flagged as present in the vertex info buffer;
        /// otherwise <paramref name="dest"/> receives a default value.
        /// </summary>
        /// <param name="resourceManager">Resource manager holding the vertex info constant buffer binding.</param>
        /// <param name="node">Node after which the operations are inserted.</param>
        /// <param name="location">Attribute location.</param>
        /// <param name="component">Component index (1-3); used to index <c>VertexStrides[location]</c>.</param>
        /// <param name="dest">Destination operand.</param>
        /// <param name="src">Fetched component value.</param>
        /// <returns>The last inserted node (the conditional select).</returns>
        private static LinkedListNode<INode> CopyMasked(
            ResourceManager resourceManager,
            LinkedListNode<INode> node,
            int location,
            int component,
            Operand dest,
            Operand src)
        {
            Operand componentExists = Local();
            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
            node = node.List.AddAfter(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                componentExists,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexStrides), Const(location), Const(component) }));

            // Missing components read as 0.0, except W which reads as 1.0.
            return node.List.AddAfter(node, new Operation(
                Instruction.ConditionalSelect,
                dest,
                new[] { componentExists, src, ConstF(component == 3 ? 1f : 0f) }));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, a load of the base vertex (<c>VertexCounts[2]</c>) into <paramref name="dest"/>.
        /// </summary>
        /// <returns>The inserted load node.</returns>
        private static LinkedListNode<INode> GenerateBaseVertexLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;

            return node.List.AddBefore(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                dest,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(2) }));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, a load of the base instance (<c>VertexCounts[3]</c>) into <paramref name="dest"/>.
        /// </summary>
        /// <returns>The inserted load node.</returns>
        private static LinkedListNode<INode> GenerateBaseInstanceLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;

            return node.List.AddBefore(node, new Operation(
                Instruction.Load,
                StorageKind.ConstantBuffer,
                dest,
                new[] { Const(vertexInfoCbBinding), Const((int)VertexInfoBufferField.VertexCounts), Const(3) }));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, <c>dest = baseVertex + vertexRateIndex</c>.
        /// </summary>
        /// <returns>The inserted add node.</returns>
        private static LinkedListNode<INode> GenerateVertexIndexLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            Operand baseVertex = Local();
            Operand vertexId = Local();

            GenerateBaseVertexLoad(resourceManager, node, baseVertex);
            GenerateVertexIdVertexRateLoad(resourceManager, node, vertexId);

            return node.List.AddBefore(node, new Operation(Instruction.Add, dest, new[] { baseVertex, vertexId }));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, <c>dest = baseInstance + GlobalId.y</c>.
        /// </summary>
        /// <returns>The inserted add node.</returns>
        private static LinkedListNode<INode> GenerateInstanceIndexLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            Operand baseInstance = Local();
            Operand instanceId = Local();

            GenerateBaseInstanceLoad(resourceManager, node, baseInstance);

            // Reuse the instance ID helper rather than duplicating the GlobalId.y load.
            GenerateInstanceIdLoad(node, instanceId);

            return node.List.AddBefore(node, new Operation(Instruction.Add, dest, new[] { baseInstance, instanceId }));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, a load of the vertex-rate vertex index from its local memory slot.
        /// </summary>
        /// <returns>The inserted load node.</returns>
        private static LinkedListNode<INode> GenerateVertexIdVertexRateLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            Operand[] sources = new Operand[] { Const(resourceManager.LocalVertexIndexVertexRateMemoryId) };

            return node.List.AddBefore(node, new Operation(Instruction.Load, StorageKind.LocalMemory, dest, sources));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, a load of the instance-rate vertex index from its local memory slot.
        /// </summary>
        /// <returns>The inserted load node.</returns>
        private static LinkedListNode<INode> GenerateVertexIdInstanceRateLoad(ResourceManager resourceManager, LinkedListNode<INode> node, Operand dest)
        {
            Operand[] sources = new Operand[] { Const(resourceManager.LocalVertexIndexInstanceRateMemoryId) };

            return node.List.AddBefore(node, new Operation(Instruction.Load, StorageKind.LocalMemory, dest, sources));
        }

        /// <summary>
        /// Emits, before <paramref name="node"/>, a load of <c>GlobalId.y</c>, used as the instance ID.
        /// </summary>
        /// <returns>The inserted load node.</returns>
        private static LinkedListNode<INode> GenerateInstanceIdLoad(LinkedListNode<INode> node, Operand dest)
        {
            Operand[] sources = new Operand[] { Const((int)IoVariable.GlobalId), Const(1) };

            return node.List.AddBefore(node, new Operation(Instruction.Load, StorageKind.Input, dest, sources));
        }

        /// <summary>
        /// Resolves the local memory offset reserved for the output accessed by <paramref name="operation"/>.
        /// </summary>
        /// <remarks>
        /// For user-defined outputs, source 1 is the location and the last index source is the component.
        /// For stores, the last index source is the one just before the stored value.
        /// Vector/array built-ins are indexed by their last index source; other built-ins are looked up directly.
        /// </remarks>
        /// <param name="resourceManager">Resource manager holding the output reservations.</param>
        /// <param name="operation">Output load or store operation.</param>
        /// <param name="outputOffset">Receives the reserved offset when found.</param>
        /// <returns>True if an offset is reserved for the output; otherwise false.</returns>
        private static bool TryGetOutputOffset(ResourceManager resourceManager, Operation operation, out int outputOffset)
        {
            bool isStore = operation.Inst == Instruction.Store;

            IoVariable ioVariable = (IoVariable)operation.GetSource(0).Value;

            bool isValidOutput;

            if (ioVariable == IoVariable.UserDefined)
            {
                int lastIndex = operation.SourcesCount - (isStore ? 2 : 1);

                int location = operation.GetSource(1).Value;
                int component = operation.GetSource(lastIndex).Value;

                isValidOutput = resourceManager.Reservations.TryGetOffset(StorageKind.Output, location, component, out outputOffset);
            }
            else
            {
                if (ResourceReservations.IsVectorOrArrayVariable(ioVariable))
                {
                    int component = operation.GetSource(operation.SourcesCount - (isStore ? 2 : 1)).Value;

                    isValidOutput = resourceManager.Reservations.TryGetOffset(StorageKind.Output, ioVariable, component, out outputOffset);
                }
                else
                {
                    isValidOutput = resourceManager.Reservations.TryGetOffset(StorageKind.Output, ioVariable, out outputOffset);
                }
            }

            return isValidOutput;
        }
    }
}