diff options
author | gdkchan <gab.dark.100@gmail.com> | 2020-10-25 17:00:44 -0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-10-25 17:00:44 -0300 |
commit | 49f970d5bd9163e2b4e26a33ef8f84529174d5de (patch) | |
tree | eceaf9c0454d27413ca77689c06a24b47467d1a0 /Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs | |
parent | 973a615d405a83d5fc2f6a11ad12ba63c2a76465 (diff) |
Implement CAL and RET shader instructions (#1618)
* Add support for CAL and RET shader instructions
* Remove unused stuff
* Fix a bug that could cause the wrong values to be passed to a function
* Avoid repopulating function id dictionary every time
* PR feedback
* Fix vertex shader A/B merge
Diffstat (limited to 'Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs')
-rw-r--r-- | Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs | 484 |
1 files changed, 484 insertions, 0 deletions
diff --git a/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs new file mode 100644 index 00000000..fd90391f --- /dev/null +++ b/Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs @@ -0,0 +1,484 @@ +using Ryujinx.Graphics.Shader.Decoders; +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Numerics; + +namespace Ryujinx.Graphics.Shader.Translation +{ + static class RegisterUsage + { + private const int RegsCount = 256; + private const int RegsMask = RegsCount - 1; + + private const int GprMasks = 4; + private const int PredMasks = 1; + private const int FlagMasks = 1; + private const int TotalMasks = GprMasks + PredMasks + FlagMasks; + + private struct RegisterMask : IEquatable<RegisterMask> + { + public long GprMask0 { get; set; } + public long GprMask1 { get; set; } + public long GprMask2 { get; set; } + public long GprMask3 { get; set; } + public long PredMask { get; set; } + public long FlagMask { get; set; } + + public RegisterMask(long gprMask0, long gprMask1, long gprMask2, long gprMask3, long predMask, long flagMask) + { + GprMask0 = gprMask0; + GprMask1 = gprMask1; + GprMask2 = gprMask2; + GprMask3 = gprMask3; + PredMask = predMask; + FlagMask = flagMask; + } + + public long GetMask(int index) + { + return index switch + { + 0 => GprMask0, + 1 => GprMask1, + 2 => GprMask2, + 3 => GprMask3, + 4 => PredMask, + 5 => FlagMask, + _ => throw new ArgumentOutOfRangeException(nameof(index)) + }; + } + + public static RegisterMask operator &(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 & y.GprMask0, + x.GprMask1 & y.GprMask1, + x.GprMask2 & y.GprMask2, + x.GprMask3 & y.GprMask3, + x.PredMask & y.PredMask, + x.FlagMask & y.FlagMask); + } + + public static RegisterMask operator |(RegisterMask x, RegisterMask y) + { + return new RegisterMask( + x.GprMask0 | y.GprMask0, + x.GprMask1 | y.GprMask1, + x.GprMask2 | y.GprMask2, + x.GprMask3 | y.GprMask3, + x.PredMask | y.PredMask, + x.FlagMask | y.FlagMask); + } + + public static RegisterMask operator ~(RegisterMask x) + { + return new RegisterMask( + ~x.GprMask0, + ~x.GprMask1, + ~x.GprMask2, + ~x.GprMask3, + ~x.PredMask, + ~x.FlagMask); + } + + public static bool operator ==(RegisterMask x, RegisterMask y) + { + return x.Equals(y); + } + + public static bool operator !=(RegisterMask x, RegisterMask y) + { + return !x.Equals(y); + } + + public override bool Equals(object obj) + { + return obj is RegisterMask regMask && Equals(regMask); + } + + public bool Equals(RegisterMask other) + { + return GprMask0 == other.GprMask0 && + GprMask1 == other.GprMask1 && + GprMask2 == other.GprMask2 && + GprMask3 == other.GprMask3 && + PredMask == other.PredMask && + FlagMask == other.FlagMask; + } + + public override int GetHashCode() + { + return HashCode.Combine(GprMask0, GprMask1, GprMask2, GprMask3, PredMask, FlagMask); + } + } + + public struct FunctionRegisterUsage + { + public Register[] InArguments { get; } + public Register[] OutArguments { get; } + + public FunctionRegisterUsage(Register[] inArguments, Register[] outArguments) + { + InArguments = inArguments; + OutArguments = outArguments; + } + } + + public static FunctionRegisterUsage RunPass(ControlFlowGraph cfg) + { + List<Register> inArguments = new List<Register>(); + List<Register> outArguments = new List<Register>(); + + // Compute local register inputs and outputs used inside blocks. + RegisterMask[] localInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] localOutputs = new RegisterMask[cfg.Blocks.Length]; + + foreach (BasicBlock block in cfg.Blocks) + { + for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) + { + Operand source = operation.GetSource(srcIndex); + + if (source.Type != OperandType.Register) + { + continue; + } + + Register register = source.GetRegister(); + + localInputs[block.Index] |= GetMask(register) & ~localOutputs[block.Index]; + } + + if (operation.Dest != null && operation.Dest.Type == OperandType.Register) + { + localOutputs[block.Index] |= GetMask(operation.Dest.GetRegister()); + } + } + } + + // Compute global register inputs and outputs used across blocks. + RegisterMask[] globalCmnOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask[] globalInputs = new RegisterMask[cfg.Blocks.Length]; + RegisterMask[] globalOutputs = new RegisterMask[cfg.Blocks.Length]; + + RegisterMask allOutputs = new RegisterMask(); + RegisterMask allCmnOutputs = new RegisterMask(-1L, -1L, -1L, -1L, -1L, -1L); + + bool modified; + + bool firstPass = true; + + do + { + modified = false; + + // Compute register outputs. + for (int index = cfg.PostOrderBlocks.Length - 1; index >= 0; index--) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + if (block.Predecessors.Count != 0) + { + BasicBlock predecessor = block.Predecessors[0]; + + RegisterMask cmnOutputs = localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + RegisterMask outputs = globalOutputs[predecessor.Index]; + + for (int pIndex = 1; pIndex < block.Predecessors.Count; pIndex++) + { + predecessor = block.Predecessors[pIndex]; + + cmnOutputs &= localOutputs[predecessor.Index] | globalCmnOutputs[predecessor.Index]; + + outputs |= globalOutputs[predecessor.Index]; + } + + globalInputs[block.Index] |= outputs & ~cmnOutputs; + + if (!firstPass) + { + cmnOutputs &= globalCmnOutputs[block.Index]; + } + + if (EndsWithReturn(block)) + { + allCmnOutputs &= cmnOutputs | localOutputs[block.Index]; + } + + if (Exchange(globalCmnOutputs, block.Index, cmnOutputs)) + { + modified = true; + } + + outputs |= localOutputs[block.Index]; + + if (Exchange(globalOutputs, block.Index, globalOutputs[block.Index] | outputs)) + { + allOutputs |= outputs; + modified = true; + } + } + else if (Exchange(globalOutputs, block.Index, localOutputs[block.Index])) + { + allOutputs |= localOutputs[block.Index]; + modified = true; + } + } + + // Compute register inputs. + for (int index = 0; index < cfg.PostOrderBlocks.Length; index++) + { + BasicBlock block = cfg.PostOrderBlocks[index]; + + RegisterMask inputs = localInputs[block.Index]; + + if (block.Next != null) + { + inputs |= globalInputs[block.Next.Index]; + } + + if (block.Branch != null) + { + inputs |= globalInputs[block.Branch.Index]; + } + + inputs &= ~globalCmnOutputs[block.Index]; + + if (Exchange(globalInputs, block.Index, globalInputs[block.Index] | inputs)) + { + modified = true; + } + } + + firstPass = false; + } + while (modified); + + // Insert load and store context instructions where needed. + foreach (BasicBlock block in cfg.Blocks) + { + // The only block without any predecessor should be the entry block. + // It always needs a context load as it is the first block to run. + if (block.Predecessors.Count == 0) + { + RegisterMask inputs = globalInputs[block.Index] | (allOutputs & ~allCmnOutputs); + + LoadLocals(block, inputs, inArguments); + } + + if (EndsWithReturn(block)) + { + StoreLocals(block, allOutputs, inArguments.Count, outArguments); + } + } + + return new FunctionRegisterUsage(inArguments.ToArray(), outArguments.ToArray()); + } + + public static void FixupCalls(BasicBlock[] blocks, FunctionRegisterUsage[] frus) + { + foreach (BasicBlock block in blocks) + { + for (LinkedListNode<INode> node = block.Operations.First; node != null; node = node.Next) + { + Operation operation = node.Value as Operation; + + if (operation.Inst == Instruction.Call) + { + Operand funcId = operation.GetSource(0); + + Debug.Assert(funcId.Type == OperandType.Constant); + + var fru = frus[funcId.Value]; + + Operand[] regs = new Operand[fru.InArguments.Length]; + + for (int i = 0; i < fru.InArguments.Length; i++) + { + regs[i] = OperandHelper.Register(fru.InArguments[i]); + } + + operation.AppendOperands(regs); + + for (int i = 0; i < fru.OutArguments.Length; i++) + { + Operation callOutArgOp = new Operation(Instruction.CallOutArgument, OperandHelper.Register(fru.OutArguments[i])); + + node = block.Operations.AddAfter(node, callOutArgOp); + } + } + } + } + } + + private static bool StartsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.First.Value is Operation operation && operation.Inst == inst; + } + + private static bool EndsWith(BasicBlock block, Instruction inst) + { + if (block.Operations.Count == 0) + { + return false; + } + + return block.Operations.Last.Value is Operation operation && operation.Inst == inst; + } + + private static RegisterMask GetMask(Register register) + { + Span<long> gprMasks = stackalloc long[4]; + long predMask = 0; + long flagMask = 0; + + switch (register.Type) + { + case RegisterType.Gpr: + gprMasks[register.Index >> 6] = 1L << (register.Index & 0x3f); + break; + case RegisterType.Predicate: + predMask = 1L << register.Index; + break; + case RegisterType.Flag: + flagMask = 1L << register.Index; + break; + } + + return new RegisterMask(gprMasks[0], gprMasks[1], gprMasks[2], gprMasks[3], predMask, flagMask); + } + + private static bool Exchange(RegisterMask[] masks, int blkIndex, RegisterMask value) + { + RegisterMask oldValue = masks[blkIndex]; + + masks[blkIndex] = value; + + return oldValue != value; + } + + private static void LoadLocals(BasicBlock block, RegisterMask masks, List<Register> inArguments) + { + bool fillArgsList = inArguments.Count == 0; + LinkedListNode<INode> node = null; + int argIndex = 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + inArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Register(register), OperandHelper.Argument(argIndex++)); + + if (node == null) + { + node = block.Operations.AddFirst(copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArguments.Count); + } + + private static void StoreLocals(BasicBlock block, RegisterMask masks, int inArgumentsCount, List<Register> outArguments) + { + LinkedListNode<INode> node = null; + int argIndex = inArgumentsCount; + bool fillArgsList = outArguments.Count == 0; + + for (int i = 0; i < TotalMasks; i++) + { + (RegisterType regType, int baseRegIndex) = GetRegTypeAndBaseIndex(i); + long mask = masks.GetMask(i); + + while (mask != 0) + { + int bit = BitOperations.TrailingZeroCount(mask); + + mask &= ~(1L << bit); + + Register register = new Register(baseRegIndex + bit, regType); + + if (fillArgsList) + { + outArguments.Add(register); + } + + Operation copyOp = new Operation(Instruction.Copy, OperandHelper.Argument(argIndex++), OperandHelper.Register(register)); + + if (node == null) + { + node = block.Operations.AddBefore(block.Operations.Last, copyOp); + } + else + { + node = block.Operations.AddAfter(node, copyOp); + } + } + } + + Debug.Assert(argIndex <= inArgumentsCount + outArguments.Count); + } + + private static (RegisterType RegType, int BaseRegIndex) GetRegTypeAndBaseIndex(int i) + { + RegisterType regType = RegisterType.Gpr; + int baseRegIndex = 0; + + if (i < GprMasks) + { + baseRegIndex = i * sizeof(long) * 8; + } + else if (i == GprMasks) + { + regType = RegisterType.Predicate; + } + else + { + regType = RegisterType.Flag; + } + + return (regType, baseRegIndex); + } + + private static bool EndsWithReturn(BasicBlock block) + { + if (!(block.GetLastOp() is Operation operation)) + { + return false; + } + + return operation.Inst == Instruction.Return; + } + } +}
\ No newline at end of file |