aboutsummaryrefslogtreecommitdiff
path: root/src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs')
-rw-r--r--src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs96
1 files changed, 96 insertions, 0 deletions
diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs
new file mode 100644
index 00000000..e55f4355
--- /dev/null
+++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/VectorComponentSelect.cs
@@ -0,0 +1,96 @@
+using Ryujinx.Graphics.Shader.IntermediateRepresentation;
+using Ryujinx.Graphics.Shader.StructuredIr;
+using System.Collections.Generic;
+
+using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
+
+namespace Ryujinx.Graphics.Shader.Translation.Transforms
+{
+ class VectorComponentSelect : ITransformPass
+ {
+ public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures)
+ {
+ return gpuAccessor.QueryHostHasVectorIndexingBug();
+ }
+
+ public static LinkedListNode<INode> RunPass(TransformContext context, LinkedListNode<INode> node)
+ {
+ Operation operation = (Operation)node.Value;
+
+ if (operation.Inst != Instruction.Load ||
+ operation.StorageKind != StorageKind.ConstantBuffer ||
+ operation.SourcesCount < 3)
+ {
+ return node;
+ }
+
+ Operand bindingIndex = operation.GetSource(0);
+ Operand fieldIndex = operation.GetSource(1);
+ Operand elemIndex = operation.GetSource(operation.SourcesCount - 1);
+
+ if (bindingIndex.Type != OperandType.Constant ||
+ fieldIndex.Type != OperandType.Constant ||
+ elemIndex.Type == OperandType.Constant)
+ {
+ return node;
+ }
+
+ BufferDefinition buffer = context.ResourceManager.Properties.ConstantBuffers[bindingIndex.Value];
+ StructureField field = buffer.Type.Fields[fieldIndex.Value];
+
+ int elemCount = (field.Type & AggregateType.ElementCountMask) switch
+ {
+ AggregateType.Vector2 => 2,
+ AggregateType.Vector3 => 3,
+ AggregateType.Vector4 => 4,
+ _ => 1
+ };
+
+ if (elemCount == 1)
+ {
+ return node;
+ }
+
+ Operand result = null;
+
+ for (int i = 0; i < elemCount; i++)
+ {
+ Operand value = Local();
+ Operand[] inputs = new Operand[operation.SourcesCount];
+
+ for (int srcIndex = 0; srcIndex < inputs.Length - 1; srcIndex++)
+ {
+ inputs[srcIndex] = operation.GetSource(srcIndex);
+ }
+
+ inputs[^1] = Const(i);
+
+ Operation loadOp = new(Instruction.Load, StorageKind.ConstantBuffer, value, inputs);
+
+ node.List.AddBefore(node, loadOp);
+
+ if (i == 0)
+ {
+ result = value;
+ }
+ else
+ {
+ Operand isCurrentIndex = Local();
+ Operand selection = Local();
+
+ Operation compareOp = new(Instruction.CompareEqual, isCurrentIndex, new Operand[] { elemIndex, Const(i) });
+ Operation selectOp = new(Instruction.ConditionalSelect, selection, new Operand[] { isCurrentIndex, value, result });
+
+ node.List.AddBefore(node, compareOp);
+ node.List.AddBefore(node, selectOp);
+
+ result = selection;
+ }
+ }
+
+ operation.TurnIntoCopy(result);
+
+ return node;
+ }
+ }
+}