using Ryujinx.Graphics.Shader.IntermediateRepresentation; using System; using System.Collections.Generic; using System.Linq; using System.Text; using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; namespace Ryujinx.Graphics.Shader.Translation.Optimizations { static class GlobalToStorage { private const int DriverReservedCb = 0; enum LsMemoryType { Local, Shared, } private class GtsContext { private readonly struct Entry { public readonly int FunctionId; public readonly Instruction Inst; public readonly StorageKind StorageKind; public readonly bool IsMultiTarget; public readonly IReadOnlyList TargetCbs; public Entry( int functionId, Instruction inst, StorageKind storageKind, bool isMultiTarget, IReadOnlyList targetCbs) { FunctionId = functionId; Inst = inst; StorageKind = storageKind; IsMultiTarget = isMultiTarget; TargetCbs = targetCbs; } } private readonly struct LsKey : IEquatable { public readonly Operand BaseOffset; public readonly int ConstOffset; public readonly LsMemoryType Type; public LsKey(Operand baseOffset, int constOffset, LsMemoryType type) { BaseOffset = baseOffset; ConstOffset = constOffset; Type = type; } public override int GetHashCode() { return HashCode.Combine(BaseOffset, ConstOffset, Type); } public override bool Equals(object obj) { return obj is LsKey other && Equals(other); } public bool Equals(LsKey other) { return other.BaseOffset == BaseOffset && other.ConstOffset == ConstOffset && other.Type == Type; } } private readonly List _entries; private readonly Dictionary> _sharedEntries; private readonly HelperFunctionManager _hfm; public GtsContext(HelperFunctionManager hfm) { _entries = new List(); _sharedEntries = new Dictionary>(); _hfm = hfm; } public int AddFunction(Operation baseOp, bool isMultiTarget, IReadOnlyList targetCbs, Function function) { int functionId = _hfm.AddFunction(function); _entries.Add(new Entry(functionId, baseOp.Inst, baseOp.StorageKind, isMultiTarget, targetCbs)); return functionId; } public bool TryGetFunctionId(Operation baseOp, bool isMultiTarget, IReadOnlyList targetCbs, out int functionId) { foreach (Entry entry in _entries) { if (entry.Inst != baseOp.Inst || entry.StorageKind != baseOp.StorageKind || entry.IsMultiTarget != isMultiTarget || entry.TargetCbs.Count != targetCbs.Count) { continue; } bool allEqual = true; for (int index = 0; index < targetCbs.Count; index++) { if (targetCbs[index] != entry.TargetCbs[index]) { allEqual = false; break; } } if (allEqual) { functionId = entry.FunctionId; return true; } } functionId = -1; return false; } public void AddMemoryTargetCb(LsMemoryType type, Operand baseOffset, int constOffset, uint targetCb, SearchResult result) { LsKey key = new(baseOffset, constOffset, type); if (!_sharedEntries.TryGetValue(key, out Dictionary targetCbs)) { // No entry with this base offset, create a new one. targetCbs = new Dictionary() { { targetCb, result } }; _sharedEntries.Add(key, targetCbs); } else if (targetCbs.TryGetValue(targetCb, out SearchResult existingResult)) { // If our entry already exists, but does not match the new result, // we set the offset to null to indicate there are multiple possible offsets. // This will be used on the multi-target access that does not need to know the offset. if (existingResult.Offset != null && (existingResult.Offset != result.Offset || existingResult.ConstOffset != result.ConstOffset)) { targetCbs[targetCb] = new SearchResult(result.SbCbSlot, result.SbCbOffset); } } else { // An entry for this base offset already exists, but not for the specified // constant buffer region where the storage buffer base address and size // comes from. targetCbs.Add(targetCb, result); } } public bool TryGetMemoryTargetCb(LsMemoryType type, Operand baseOffset, int constOffset, out SearchResult result) { LsKey key = new(baseOffset, constOffset, type); if (_sharedEntries.TryGetValue(key, out Dictionary targetCbs) && targetCbs.Count == 1) { SearchResult candidateResult = targetCbs.Values.First(); if (candidateResult.Found) { result = candidateResult; return true; } } result = default; return false; } } private readonly struct SearchResult { public static SearchResult NotFound => new(-1, 0); public bool Found => SbCbSlot != -1; public int SbCbSlot { get; } public int SbCbOffset { get; } public Operand Offset { get; } public int ConstOffset { get; } public SearchResult(int sbCbSlot, int sbCbOffset) { SbCbSlot = sbCbSlot; SbCbOffset = sbCbOffset; } public SearchResult(int sbCbSlot, int sbCbOffset, Operand offset, int constOffset = 0) { SbCbSlot = sbCbSlot; SbCbOffset = sbCbOffset; Offset = offset; ConstOffset = constOffset; } } public static void RunPass( HelperFunctionManager hfm, BasicBlock[] blocks, ResourceManager resourceManager, IGpuAccessor gpuAccessor, TargetLanguage targetLanguage) { GtsContext gtsContext = new(hfm); foreach (BasicBlock block in blocks) { for (LinkedListNode node = block.Operations.First; node != null; node = node.Next) { if (node.Value is not Operation operation) { continue; } if (IsGlobalMemory(operation.StorageKind)) { LinkedListNode nextNode = ReplaceGlobalMemoryWithStorage( gtsContext, resourceManager, gpuAccessor, targetLanguage, block, node); if (nextNode == null) { // The returned value being null means that the global memory replacement failed, // so we just make loads read 0 and stores do nothing. gpuAccessor.Log($"Failed to reserve storage buffer for global memory operation \"{operation.Inst}\"."); if (operation.Dest != null) { operation.TurnIntoCopy(Const(0)); } else { Utils.DeleteNode(node, operation); } } else { node = nextNode; } } else if (operation.Inst == Instruction.Store && (operation.StorageKind == StorageKind.SharedMemory || operation.StorageKind == StorageKind.LocalMemory)) { // The NVIDIA compiler can sometimes use shared or local memory as temporary // storage to place the base address and size on, so we need // to be able to find such information stored in memory too. if (TryGetMemoryOffsets(operation, out LsMemoryType type, out Operand baseOffset, out int constOffset)) { Operand value = operation.GetSource(operation.SourcesCount - 1); var result = FindUniqueBaseAddressCb(gtsContext, block, value, needsOffset: false); if (result.Found) { uint targetCb = PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset); gtsContext.AddMemoryTargetCb(type, baseOffset, constOffset, targetCb, result); } } } } } } private static bool IsGlobalMemory(StorageKind storageKind) { return storageKind == StorageKind.GlobalMemory || storageKind == StorageKind.GlobalMemoryS8 || storageKind == StorageKind.GlobalMemoryS16 || storageKind == StorageKind.GlobalMemoryU8 || storageKind == StorageKind.GlobalMemoryU16; } private static bool IsSmallInt(StorageKind storageKind) { return storageKind == StorageKind.GlobalMemoryS8 || storageKind == StorageKind.GlobalMemoryS16 || storageKind == StorageKind.GlobalMemoryU8 || storageKind == StorageKind.GlobalMemoryU16; } private static LinkedListNode ReplaceGlobalMemoryWithStorage( GtsContext gtsContext, ResourceManager resourceManager, IGpuAccessor gpuAccessor, TargetLanguage targetLanguage, BasicBlock block, LinkedListNode node) { Operation operation = node.Value as Operation; Operand globalAddress = operation.GetSource(0); SearchResult result = FindUniqueBaseAddressCb(gtsContext, block, globalAddress, needsOffset: true); if (result.Found) { // We found the storage buffer that is being accessed. // There are two possible paths here, if the operation is simple enough, // we just generate the storage access code inline. // Otherwise, we generate a function call (and the function if necessary). Operand offset = result.Offset; bool storageUnaligned = gpuAccessor.QueryHasUnalignedStorageBuffer(); if (storageUnaligned) { Operand baseAddress = Cbuf(result.SbCbSlot, result.SbCbOffset); Operand baseAddressMasked = Local(); Operand hostOffset = Local(); int alignment = gpuAccessor.QueryHostStorageBufferOffsetAlignment(); Operation maskOp = new(Instruction.BitwiseAnd, baseAddressMasked, baseAddress, Const(-alignment)); Operation subOp = new(Instruction.Subtract, hostOffset, globalAddress, baseAddressMasked); node.List.AddBefore(node, maskOp); node.List.AddBefore(node, subOp); offset = hostOffset; } else if (result.ConstOffset != 0) { Operand newOffset = Local(); Operation addOp = new(Instruction.Add, newOffset, offset, Const(result.ConstOffset)); node.List.AddBefore(node, addOp); offset = newOffset; } if (CanUseInlineStorageOp(operation, targetLanguage)) { return GenerateInlineStorageOp(resourceManager, node, operation, offset, result); } else { if (!TryGenerateSingleTargetStorageOp( gtsContext, resourceManager, targetLanguage, operation, result, out int functionId)) { return null; } return GenerateCallStorageOp(node, operation, offset, functionId); } } else { // Failed to find the storage buffer directly. // Try to walk through Phi chains and find all possible constant buffers where // the base address might be stored. // Generate a helper function that will check all possible storage buffers and use the right one. if (!TryGenerateMultiTargetStorageOp( gtsContext, resourceManager, gpuAccessor, targetLanguage, block, operation, out int functionId)) { return null; } return GenerateCallStorageOp(node, operation, null, functionId); } } private static bool CanUseInlineStorageOp(Operation operation, TargetLanguage targetLanguage) { if (operation.StorageKind != StorageKind.GlobalMemory) { return false; } return (operation.Inst != Instruction.AtomicMaxS32 && operation.Inst != Instruction.AtomicMinS32) || targetLanguage == TargetLanguage.Spirv; } private static LinkedListNode GenerateInlineStorageOp( ResourceManager resourceManager, LinkedListNode node, Operation operation, Operand offset, SearchResult result) { bool isStore = operation.Inst == Instruction.Store || operation.Inst.IsAtomic(); if (!resourceManager.TryGetStorageBufferBinding(result.SbCbSlot, result.SbCbOffset, isStore, out int binding)) { return null; } Operand wordOffset = Local(); Operand[] sources; if (operation.Inst == Instruction.AtomicCompareAndSwap) { sources = new[] { Const(binding), Const(0), wordOffset, operation.GetSource(operation.SourcesCount - 2), operation.GetSource(operation.SourcesCount - 1), }; } else if (isStore) { sources = new[] { Const(binding), Const(0), wordOffset, operation.GetSource(operation.SourcesCount - 1) }; } else { sources = new[] { Const(binding), Const(0), wordOffset }; } Operation shiftOp = new(Instruction.ShiftRightU32, wordOffset, offset, Const(2)); Operation storageOp = new(operation.Inst, StorageKind.StorageBuffer, operation.Dest, sources); node.List.AddBefore(node, shiftOp); LinkedListNode newNode = node.List.AddBefore(node, storageOp); Utils.DeleteNode(node, operation); return newNode; } private static LinkedListNode GenerateCallStorageOp(LinkedListNode node, Operation operation, Operand offset, int functionId) { // Generate call to a helper function that will perform the storage buffer operation. Operand[] sources = new Operand[operation.SourcesCount - 1 + (offset == null ? 2 : 1)]; sources[0] = Const(functionId); if (offset != null) { // If the offset was supplised, we use that and skip the global address. sources[1] = offset; for (int srcIndex = 2; srcIndex < operation.SourcesCount; srcIndex++) { sources[srcIndex] = operation.GetSource(srcIndex); } } else { // Use the 64-bit global address which is split in 2 32-bit arguments. for (int srcIndex = 0; srcIndex < operation.SourcesCount; srcIndex++) { sources[srcIndex + 1] = operation.GetSource(srcIndex); } } bool returnsValue = operation.Dest != null; Operand returnValue = returnsValue ? Local() : null; Operation callOp = new(Instruction.Call, returnValue, sources); LinkedListNode newNode = node.List.AddBefore(node, callOp); if (returnsValue) { operation.TurnIntoCopy(returnValue); return node; } else { Utils.DeleteNode(node, operation); return newNode; } } private static bool TryGenerateSingleTargetStorageOp( GtsContext gtsContext, ResourceManager resourceManager, TargetLanguage targetLanguage, Operation operation, SearchResult result, out int functionId) { List targetCbs = new() { PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset) }; if (gtsContext.TryGetFunctionId(operation, isMultiTarget: false, targetCbs, out functionId)) { return true; } int inArgumentsCount = 1; if (operation.Inst == Instruction.AtomicCompareAndSwap) { inArgumentsCount = 3; } else if (operation.Inst == Instruction.Store || operation.Inst.IsAtomic()) { inArgumentsCount = 2; } EmitterContext context = new(); Operand offset = Argument(0); Operand compare = null; Operand value = null; if (inArgumentsCount == 3) { compare = Argument(1); value = Argument(2); } else if (inArgumentsCount == 2) { value = Argument(1); } if (!TryGenerateStorageOp( resourceManager, targetLanguage, context, operation.Inst, operation.StorageKind, offset, compare, value, result, out Operand resultValue)) { functionId = 0; return false; } bool returnsValue = resultValue != null; if (returnsValue) { context.Return(resultValue); } else { context.Return(); } string functionName = GetFunctionName(operation, isMultiTarget: false, targetCbs); Function function = new( ControlFlowGraph.Create(context.GetOperations()).Blocks, functionName, returnsValue, inArgumentsCount, 0); functionId = gtsContext.AddFunction(operation, isMultiTarget: false, targetCbs, function); return true; } private static bool TryGenerateMultiTargetStorageOp( GtsContext gtsContext, ResourceManager resourceManager, IGpuAccessor gpuAccessor, TargetLanguage targetLanguage, BasicBlock block, Operation operation, out int functionId) { Queue phis = new(); HashSet visited = new(); List targetCbs = new(); Operand globalAddress = operation.GetSource(0); if (globalAddress.AsgOp is Operation addOp && addOp.Inst == Instruction.Add) { Operand src1 = addOp.GetSource(0); Operand src2 = addOp.GetSource(1); if (src1.Type == OperandType.Constant && src2.Type == OperandType.LocalVariable) { globalAddress = src2; } else if (src1.Type == OperandType.LocalVariable && src2.Type == OperandType.Constant) { globalAddress = src1; } } if (globalAddress.AsgOp is PhiNode phi && visited.Add(phi)) { phis.Enqueue(phi); } else { SearchResult result = FindUniqueBaseAddressCb(gtsContext, block, operation.GetSource(0), needsOffset: false); if (result.Found) { targetCbs.Add(PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset)); } } while (phis.TryDequeue(out phi)) { for (int srcIndex = 0; srcIndex < phi.SourcesCount; srcIndex++) { BasicBlock phiBlock = phi.GetBlock(srcIndex); Operand phiSource = phi.GetSource(srcIndex); SearchResult result = FindUniqueBaseAddressCb(gtsContext, phiBlock, phiSource, needsOffset: false); if (result.Found) { uint targetCb = PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset); if (!targetCbs.Contains(targetCb)) { targetCbs.Add(targetCb); } } else if (phiSource.AsgOp is PhiNode phi2 && visited.Add(phi2)) { phis.Enqueue(phi2); } } } targetCbs.Sort(); if (targetCbs.Count == 0) { gpuAccessor.Log($"Failed to find storage buffer for global memory operation \"{operation.Inst}\"."); } if (gtsContext.TryGetFunctionId(operation, isMultiTarget: true, targetCbs, out functionId)) { return true; } int inArgumentsCount = 2; if (operation.Inst == Instruction.AtomicCompareAndSwap) { inArgumentsCount = 4; } else if (operation.Inst == Instruction.Store || operation.Inst.IsAtomic()) { inArgumentsCount = 3; } EmitterContext context = new(); Operand globalAddressLow = Argument(0); Operand globalAddressHigh = Argument(1); foreach (uint targetCb in targetCbs) { (int sbCbSlot, int sbCbOffset) = UnpackCbSlotAndOffset(targetCb); Operand baseAddrLow = Cbuf(sbCbSlot, sbCbOffset); Operand baseAddrHigh = Cbuf(sbCbSlot, sbCbOffset + 1); Operand size = Cbuf(sbCbSlot, sbCbOffset + 2); Operand offset = context.ISubtract(globalAddressLow, baseAddrLow); Operand borrow = context.ICompareLessUnsigned(globalAddressLow, baseAddrLow); Operand inRangeLow = context.ICompareLessUnsigned(offset, size); Operand addrHighBorrowed = context.IAdd(globalAddressHigh, borrow); Operand inRangeHigh = context.ICompareEqual(addrHighBorrowed, baseAddrHigh); Operand inRange = context.BitwiseAnd(inRangeLow, inRangeHigh); Operand lblSkip = Label(); context.BranchIfFalse(lblSkip, inRange); Operand compare = null; Operand value = null; if (inArgumentsCount == 4) { compare = Argument(2); value = Argument(3); } else if (inArgumentsCount == 3) { value = Argument(2); } SearchResult result = new(sbCbSlot, sbCbOffset); int alignment = gpuAccessor.QueryHostStorageBufferOffsetAlignment(); Operand baseAddressMasked = context.BitwiseAnd(baseAddrLow, Const(-alignment)); Operand hostOffset = context.ISubtract(globalAddressLow, baseAddressMasked); if (!TryGenerateStorageOp( resourceManager, targetLanguage, context, operation.Inst, operation.StorageKind, hostOffset, compare, value, result, out Operand resultValue)) { functionId = 0; return false; } if (resultValue != null) { context.Return(resultValue); } else { context.Return(); } context.MarkLabel(lblSkip); } bool returnsValue = operation.Dest != null; if (returnsValue) { context.Return(Const(0)); } else { context.Return(); } string functionName = GetFunctionName(operation, isMultiTarget: true, targetCbs); Function function = new( ControlFlowGraph.Create(context.GetOperations()).Blocks, functionName, returnsValue, inArgumentsCount, 0); functionId = gtsContext.AddFunction(operation, isMultiTarget: true, targetCbs, function); return true; } private static uint PackCbSlotAndOffset(int cbSlot, int cbOffset) { return (uint)((ushort)cbSlot | ((ushort)cbOffset << 16)); } private static (int, int) UnpackCbSlotAndOffset(uint packed) { return ((ushort)packed, (ushort)(packed >> 16)); } private static string GetFunctionName(Operation baseOp, bool isMultiTarget, IReadOnlyList targetCbs) { StringBuilder nameBuilder = new(); nameBuilder.Append(baseOp.Inst.ToString()); nameBuilder.Append(baseOp.StorageKind switch { StorageKind.GlobalMemoryS8 => "S8", StorageKind.GlobalMemoryS16 => "S16", StorageKind.GlobalMemoryU8 => "U8", StorageKind.GlobalMemoryU16 => "U16", _ => string.Empty, }); if (isMultiTarget) { nameBuilder.Append("Multi"); } foreach (uint targetCb in targetCbs) { (int sbCbSlot, int sbCbOffset) = UnpackCbSlotAndOffset(targetCb); nameBuilder.Append($"_c{sbCbSlot}o{sbCbOffset}"); } return nameBuilder.ToString(); } private static bool TryGenerateStorageOp( ResourceManager resourceManager, TargetLanguage targetLanguage, EmitterContext context, Instruction inst, StorageKind storageKind, Operand offset, Operand compare, Operand value, SearchResult result, out Operand resultValue) { resultValue = null; bool isStore = inst.IsAtomic() || inst == Instruction.Store; if (!resourceManager.TryGetStorageBufferBinding(result.SbCbSlot, result.SbCbOffset, isStore, out int binding)) { return false; } Operand wordOffset = context.ShiftRightU32(offset, Const(2)); if (inst.IsAtomic()) { if (IsSmallInt(storageKind)) { throw new NotImplementedException(); } switch (inst) { case Instruction.AtomicAdd: resultValue = context.AtomicAdd(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicAnd: resultValue = context.AtomicAnd(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicCompareAndSwap: resultValue = context.AtomicCompareAndSwap(StorageKind.StorageBuffer, binding, Const(0), wordOffset, compare, value); break; case Instruction.AtomicMaxS32: if (targetLanguage == TargetLanguage.Spirv) { resultValue = context.AtomicMaxS32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); } else { resultValue = GenerateAtomicCasLoop(context, wordOffset, binding, (memValue) => { return context.IMaximumS32(memValue, value); }); } break; case Instruction.AtomicMaxU32: resultValue = context.AtomicMaxU32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicMinS32: if (targetLanguage == TargetLanguage.Spirv) { resultValue = context.AtomicMinS32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); } else { resultValue = GenerateAtomicCasLoop(context, wordOffset, binding, (memValue) => { return context.IMinimumS32(memValue, value); }); } break; case Instruction.AtomicMinU32: resultValue = context.AtomicMinU32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicOr: resultValue = context.AtomicOr(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicSwap: resultValue = context.AtomicSwap(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; case Instruction.AtomicXor: resultValue = context.AtomicXor(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); break; } } else if (inst == Instruction.Store) { int bitSize = storageKind switch { StorageKind.GlobalMemoryS8 or StorageKind.GlobalMemoryU8 => 8, StorageKind.GlobalMemoryS16 or StorageKind.GlobalMemoryU16 => 16, _ => 32, }; if (bitSize < 32) { Operand bitOffset = HelperFunctionManager.GetBitOffset(context, offset); GenerateAtomicCasLoop(context, wordOffset, binding, (memValue) => { return context.BitfieldInsert(memValue, value, bitOffset, Const(bitSize)); }); } else { context.Store(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value); } } else { value = context.Load(StorageKind.StorageBuffer, binding, Const(0), wordOffset); if (IsSmallInt(storageKind)) { Operand bitOffset = HelperFunctionManager.GetBitOffset(context, offset); switch (storageKind) { case StorageKind.GlobalMemoryS8: value = context.ShiftRightS32(value, bitOffset); value = context.BitfieldExtractS32(value, Const(0), Const(8)); break; case StorageKind.GlobalMemoryS16: value = context.ShiftRightS32(value, bitOffset); value = context.BitfieldExtractS32(value, Const(0), Const(16)); break; case StorageKind.GlobalMemoryU8: value = context.ShiftRightU32(value, bitOffset); value = context.BitwiseAnd(value, Const(byte.MaxValue)); break; case StorageKind.GlobalMemoryU16: value = context.ShiftRightU32(value, bitOffset); value = context.BitwiseAnd(value, Const(ushort.MaxValue)); break; } } resultValue = value; } return true; } private static Operand GenerateAtomicCasLoop(EmitterContext context, Operand wordOffset, int binding, Func opCallback) { Operand lblLoopHead = Label(); context.MarkLabel(lblLoopHead); Operand oldValue = context.Load(StorageKind.StorageBuffer, binding, Const(0), wordOffset); Operand newValue = opCallback(oldValue); Operand casResult = context.AtomicCompareAndSwap( StorageKind.StorageBuffer, binding, Const(0), wordOffset, oldValue, newValue); Operand casFail = context.ICompareNotEqual(casResult, oldValue); context.BranchIfTrue(lblLoopHead, casFail); return oldValue; } private static SearchResult FindUniqueBaseAddressCb(GtsContext gtsContext, BasicBlock block, Operand globalAddress, bool needsOffset) { globalAddress = Utils.FindLastOperation(globalAddress, block); if (globalAddress.Type == OperandType.ConstantBuffer) { return GetBaseAddressCbWithOffset(globalAddress, Const(0), 0); } Operation operation = globalAddress.AsgOp as Operation; if (operation == null || operation.Inst != Instruction.Add) { return FindBaseAddressCbFromMemory(gtsContext, operation, 0, needsOffset); } Operand src1 = operation.GetSource(0); Operand src2 = operation.GetSource(1); int constOffset = 0; if ((src1.Type == OperandType.LocalVariable && src2.Type == OperandType.Constant) || (src2.Type == OperandType.LocalVariable && src1.Type == OperandType.Constant)) { Operand baseAddr; Operand offset; if (src1.Type == OperandType.LocalVariable) { baseAddr = Utils.FindLastOperation(src1, block); offset = src2; } else { baseAddr = Utils.FindLastOperation(src2, block); offset = src1; } var result = GetBaseAddressCbWithOffset(baseAddr, offset, 0); if (result.Found) { return result; } constOffset = offset.Value; operation = baseAddr.AsgOp as Operation; if (operation == null || operation.Inst != Instruction.Add) { return FindBaseAddressCbFromMemory(gtsContext, operation, constOffset, needsOffset); } } src1 = operation.GetSource(0); src2 = operation.GetSource(1); // If we have two possible results, we give preference to the ones from // the driver reserved constant buffer, as those are the ones that // contains the base address. // If both are constant buffer, give preference to the second operand, // because constant buffer are always encoded as the second operand, // so the second operand will always be the one from the last instruction. if (src1.Type != OperandType.ConstantBuffer || (src1.Type == OperandType.ConstantBuffer && src2.Type == OperandType.ConstantBuffer) || (src2.Type == OperandType.ConstantBuffer && src2.GetCbufSlot() == DriverReservedCb)) { return GetBaseAddressCbWithOffset(src2, src1, constOffset); } return GetBaseAddressCbWithOffset(src1, src2, constOffset); } private static SearchResult FindBaseAddressCbFromMemory(GtsContext gtsContext, Operation operation, int constOffset, bool needsOffset) { if (operation != null) { if (TryGetMemoryOffsets(operation, out LsMemoryType type, out Operand bo, out int co) && gtsContext.TryGetMemoryTargetCb(type, bo, co, out SearchResult result) && (result.Offset != null || !needsOffset)) { if (constOffset != 0) { return new SearchResult( result.SbCbSlot, result.SbCbOffset, result.Offset, result.ConstOffset + constOffset); } return result; } } return SearchResult.NotFound; } private static SearchResult GetBaseAddressCbWithOffset(Operand baseAddress, Operand offset, int constOffset) { if (baseAddress.Type == OperandType.ConstantBuffer) { int sbCbSlot = baseAddress.GetCbufSlot(); int sbCbOffset = baseAddress.GetCbufOffset(); // We require the offset to be aligned to 1 word (64 bits), // since the address size is 64-bit and the GPU only supports aligned memory access. if ((sbCbOffset & 1) == 0) { return new SearchResult(sbCbSlot, sbCbOffset, offset, constOffset); } } return SearchResult.NotFound; } private static bool TryGetMemoryOffsets(Operation operation, out LsMemoryType type, out Operand baseOffset, out int constOffset) { baseOffset = null; if (operation.Inst == Instruction.Load || operation.Inst == Instruction.Store) { if (operation.StorageKind == StorageKind.SharedMemory) { type = LsMemoryType.Shared; return TryGetSharedMemoryOffsets(operation, out baseOffset, out constOffset); } else if (operation.StorageKind == StorageKind.LocalMemory) { type = LsMemoryType.Local; return TryGetLocalMemoryOffset(operation, out constOffset); } } type = default; constOffset = 0; return false; } private static bool TryGetSharedMemoryOffsets(Operation operation, out Operand baseOffset, out int constOffset) { baseOffset = null; constOffset = 0; // The byte offset is right shifted by 2 to get the 32-bit word offset, // so we want to get the byte offset back, since each one of those word // offsets are a new "local variable" which will not match. if (operation.GetSource(1).AsgOp is Operation shiftRightOp && shiftRightOp.Inst == Instruction.ShiftRightU32 && shiftRightOp.GetSource(1).Type == OperandType.Constant && shiftRightOp.GetSource(1).Value == 2) { baseOffset = shiftRightOp.GetSource(0); } // Check if we have a constant offset being added to the base offset. if (baseOffset?.AsgOp is Operation addOp && addOp.Inst == Instruction.Add) { Operand src1 = addOp.GetSource(0); Operand src2 = addOp.GetSource(1); if (src1.Type == OperandType.Constant && src2.Type == OperandType.LocalVariable) { constOffset = src1.Value; baseOffset = src2; } else if (src1.Type == OperandType.LocalVariable && src2.Type == OperandType.Constant) { baseOffset = src1; constOffset = src2.Value; } } return baseOffset != null && baseOffset.Type == OperandType.LocalVariable; } private static bool TryGetLocalMemoryOffset(Operation operation, out int constOffset) { Operand offset = operation.GetSource(1); if (offset.Type == OperandType.Constant) { constOffset = offset.Value; return true; } constOffset = 0; return false; } } }