From d10cf55353175b13bed4cf18791e080ecb7fd95b Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Thu, 22 Apr 2021 16:17:59 -0300
Subject: shader: Implement indexed textures

---
 .../backend/spirv/emit_context.cpp                 | 79 +++++++++++++---------
 src/shader_recompiler/backend/spirv/emit_context.h | 11 ++-
 .../backend/spirv/emit_spirv_image.cpp             | 58 +++++++++-------
 3 files changed, 92 insertions(+), 56 deletions(-)

(limited to 'src/shader_recompiler/backend')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 7f16cb0dc3..8e625f8fb9 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -380,6 +380,24 @@ Id CasLoop(EmitContext& ctx, Operation operation, Id array_pointer, Id element_p
     ctx.OpFunctionEnd();
     return func;
 }
+
+template <typename Desc>
+std::string NameOf(const Desc& desc, std::string_view prefix) {
+    if (desc.count > 1) {
+        return fmt::format("{}{}_{:02x}x{}", prefix, desc.cbuf_index, desc.cbuf_offset, desc.count);
+    } else {
+        return fmt::format("{}{}_{:02x}", prefix, desc.cbuf_index, desc.cbuf_offset);
+    }
+}
+
+Id DescType(EmitContext& ctx, Id sampled_type, Id pointer_type, u32 count) {
+    if (count > 1) {
+        const Id array_type{ctx.TypeArray(sampled_type, ctx.Const(count))};
+        return ctx.TypePointer(spv::StorageClass::UniformConstant, array_type);
+    } else {
+        return pointer_type;
+    }
+}
 } // Anonymous namespace
 
 void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
@@ -971,12 +989,15 @@ void EmitContext::DefineTextureBuffers(const Info& info, u32& binding) {
         const Id id{AddGlobalVariable(type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
         Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("texbuf{}_{:02x}", desc.cbuf_index, desc.cbuf_offset));
-        texture_buffers.insert(texture_buffers.end(), desc.count, id);
+        Name(id, NameOf(desc, "texbuf"));
+        texture_buffers.push_back({
+            .id = id,
+            .count = desc.count,
+        });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
         }
-        binding += desc.count;
+        ++binding;
     }
 }
 
@@ -992,44 +1013,41 @@ void EmitContext::DefineImageBuffers(const Info& info, u32& binding) {
         const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
         Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("imgbuf{}_{:02x}", desc.cbuf_index, desc.cbuf_offset));
-        const ImageBufferDefinition def{
+        Name(id, NameOf(desc, "imgbuf"));
+        image_buffers.push_back({
             .id = id,
             .image_type = image_type,
-        };
-        image_buffers.insert(image_buffers.end(), desc.count, def);
+            .count = desc.count,
+        });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
         }
-        binding += desc.count;
+        ++binding;
     }
 }
 
 void EmitContext::DefineTextures(const Info& info, u32& binding) {
     textures.reserve(info.texture_descriptors.size());
     for (const TextureDescriptor& desc : info.texture_descriptors) {
-        if (desc.count != 1) {
-            throw NotImplementedException("Array of textures");
-        }
         const Id image_type{ImageType(*this, desc)};
         const Id sampled_type{TypeSampledImage(image_type)};
         const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, sampled_type)};
-        const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
+        const Id desc_type{DescType(*this, sampled_type, pointer_type, desc.count)};
+        const Id id{AddGlobalVariable(desc_type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
         Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("tex{}_{:02x}", desc.cbuf_index, desc.cbuf_offset));
-        for (u32 index = 0; index < desc.count; ++index) {
-            // TODO: Pass count info
-            textures.push_back(TextureDefinition{
-                .id{id},
-                .sampled_type{sampled_type},
-                .image_type{image_type},
-            });
-        }
+        Name(id, NameOf(desc, "tex"));
+        textures.push_back({
+            .id = id,
+            .sampled_type = sampled_type,
+            .pointer_type = pointer_type,
+            .image_type = image_type,
+            .count = desc.count,
+        });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
         }
-        binding += desc.count;
+        ++binding;
     }
 }
 
@@ -1037,24 +1055,23 @@ void EmitContext::DefineImages(const Info& info, u32& binding) {
     images.reserve(info.image_descriptors.size());
     for (const ImageDescriptor& desc : info.image_descriptors) {
         if (desc.count != 1) {
-            throw NotImplementedException("Array of textures");
+            throw NotImplementedException("Array of images");
         }
         const Id image_type{ImageType(*this, desc)};
         const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, image_type)};
         const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
         Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("img{}_{:02x}", desc.cbuf_index, desc.cbuf_offset));
-        for (u32 index = 0; index < desc.count; ++index) {
-            images.push_back(ImageDefinition{
-                .id{id},
-                .image_type{image_type},
-            });
-        }
+        Name(id, NameOf(desc, "img"));
+        images.push_back({
+            .id = id,
+            .image_type = image_type,
+            .count = desc.count,
+        });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
         }
-        binding += desc.count;
+        ++binding;
     }
 }
 
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index a4503c7ab7..c52544fb7b 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -32,17 +32,26 @@ private:
 struct TextureDefinition {
     Id id;
     Id sampled_type;
+    Id pointer_type;
     Id image_type;
+    u32 count;
+};
+
+struct TextureBufferDefinition {
+    Id id;
+    u32 count;
 };
 
 struct ImageBufferDefinition {
     Id id;
     Id image_type;
+    u32 count;
 };
 
 struct ImageDefinition {
     Id id;
     Id image_type;
+    u32 count;
 };
 
 struct UniformDefinitions {
@@ -162,7 +171,7 @@ public:
 
     std::array<UniformDefinitions, Info::MAX_CBUFS> cbufs{};
     std::array<StorageDefinitions, Info::MAX_SSBOS> ssbos{};
-    std::vector<Id> texture_buffers;
+    std::vector<TextureBufferDefinition> texture_buffers;
     std::vector<ImageBufferDefinition> image_buffers;
     std::vector<TextureDefinition> textures;
     std::vector<ImageDefinition> images;
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
index 90817f1612..6008980afe 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_image.cpp
@@ -147,24 +147,31 @@ private:
     spv::ImageOperandsMask mask{};
 };
 
-Id Texture(EmitContext& ctx, const IR::Value& index) {
-    if (index.IsImmediate()) {
-        const TextureDefinition def{ctx.textures.at(index.U32())};
+Id Texture(EmitContext& ctx, IR::TextureInstInfo info, [[maybe_unused]] const IR::Value& index) {
+    const TextureDefinition& def{ctx.textures.at(info.descriptor_index)};
+    if (def.count > 1) {
+        const Id pointer{ctx.OpAccessChain(def.pointer_type, def.id, ctx.Def(index))};
+        return ctx.OpLoad(def.sampled_type, pointer);
+    } else {
         return ctx.OpLoad(def.sampled_type, def.id);
     }
-    throw NotImplementedException("Indirect texture sample");
 }
 
-Id TextureImage(EmitContext& ctx, const IR::Value& index, IR::TextureInstInfo info) {
-    if (!index.IsImmediate()) {
-        throw NotImplementedException("Indirect texture sample");
-    }
+Id TextureImage(EmitContext& ctx, IR::TextureInstInfo info,
+                [[maybe_unused]] const IR::Value& index) {
     if (info.type == TextureType::Buffer) {
-        const Id sampler_id{ctx.texture_buffers.at(index.U32())};
+        const TextureBufferDefinition& def{ctx.texture_buffers.at(info.descriptor_index)};
+        if (def.count > 1) {
+            throw NotImplementedException("Indirect texture sample");
+        }
+        const Id sampler_id{def.id};
         const Id id{ctx.OpLoad(ctx.sampled_texture_buffer_type, sampler_id)};
         return ctx.OpImage(ctx.image_buffer_type, id);
     } else {
-        const TextureDefinition def{ctx.textures.at(index.U32())};
+        const TextureDefinition& def{ctx.textures.at(info.descriptor_index)};
+        if (def.count > 1) {
+            throw NotImplementedException("Indirect texture sample");
+        }
         return ctx.OpImage(def.image_type, ctx.OpLoad(def.sampled_type, def.id));
     }
 }
@@ -311,7 +318,7 @@ Id EmitImageSampleImplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Value&
                                      bias_lc, offset);
         return Emit(&EmitContext::OpImageSparseSampleImplicitLod,
                     &EmitContext::OpImageSampleImplicitLod, ctx, inst, ctx.F32[4],
-                    Texture(ctx, index), coords, operands.Mask(), operands.Span());
+                    Texture(ctx, info, index), coords, operands.Mask(), operands.Span());
     } else {
         // We can't use implicit lods on non-fragment stages on SPIR-V. Maxwell hardware behaves as
         // if the lod was explicitly zero.  This may change on Turing with implicit compute
@@ -320,7 +327,7 @@ Id EmitImageSampleImplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Value&
         const ImageOperands operands(ctx, false, true, info.has_lod_clamp != 0, lod, offset);
         return Emit(&EmitContext::OpImageSparseSampleExplicitLod,
                     &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4],
-                    Texture(ctx, index), coords, operands.Mask(), operands.Span());
+                    Texture(ctx, info, index), coords, operands.Mask(), operands.Span());
     }
 }
 
@@ -329,8 +336,8 @@ Id EmitImageSampleExplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Value&
     const auto info{inst->Flags<IR::TextureInstInfo>()};
     const ImageOperands operands(ctx, false, true, info.has_lod_clamp != 0, lod_lc, offset);
     return Emit(&EmitContext::OpImageSparseSampleExplicitLod,
-                &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4], Texture(ctx, index),
-                coords, operands.Mask(), operands.Span());
+                &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4],
+                Texture(ctx, info, index), coords, operands.Mask(), operands.Span());
 }
 
 Id EmitImageSampleDrefImplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Value& index,
@@ -340,7 +347,7 @@ Id EmitImageSampleDrefImplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Va
                                  offset);
     return Emit(&EmitContext::OpImageSparseSampleDrefImplicitLod,
                 &EmitContext::OpImageSampleDrefImplicitLod, ctx, inst, ctx.F32[1],
-                Texture(ctx, index), coords, dref, operands.Mask(), operands.Span());
+                Texture(ctx, info, index), coords, dref, operands.Mask(), operands.Span());
 }
 
 Id EmitImageSampleDrefExplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Value& index,
@@ -349,7 +356,7 @@ Id EmitImageSampleDrefExplicitLod(EmitContext& ctx, IR::Inst* inst, const IR::Va
     const ImageOperands operands(ctx, false, true, info.has_lod_clamp != 0, lod_lc, offset);
     return Emit(&EmitContext::OpImageSparseSampleDrefExplicitLod,
                 &EmitContext::OpImageSampleDrefExplicitLod, ctx, inst, ctx.F32[1],
-                Texture(ctx, index), coords, dref, operands.Mask(), operands.Span());
+                Texture(ctx, info, index), coords, dref, operands.Mask(), operands.Span());
 }
 
 Id EmitImageGather(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords,
@@ -357,15 +364,17 @@ Id EmitImageGather(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id
     const auto info{inst->Flags<IR::TextureInstInfo>()};
     const ImageOperands operands(ctx, offset, offset2);
     return Emit(&EmitContext::OpImageSparseGather, &EmitContext::OpImageGather, ctx, inst,
-                ctx.F32[4], Texture(ctx, index), coords, ctx.Const(info.gather_component),
+                ctx.F32[4], Texture(ctx, info, index), coords, ctx.Const(info.gather_component),
                 operands.Mask(), operands.Span());
 }
 
 Id EmitImageGatherDref(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords,
                        const IR::Value& offset, const IR::Value& offset2, Id dref) {
+    const auto info{inst->Flags<IR::TextureInstInfo>()};
     const ImageOperands operands(ctx, offset, offset2);
     return Emit(&EmitContext::OpImageSparseDrefGather, &EmitContext::OpImageDrefGather, ctx, inst,
-                ctx.F32[4], Texture(ctx, index), coords, dref, operands.Mask(), operands.Span());
+                ctx.F32[4], Texture(ctx, info, index), coords, dref, operands.Mask(),
+                operands.Span());
 }
 
 Id EmitImageFetch(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords, Id offset,
@@ -376,12 +385,12 @@ Id EmitImageFetch(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id c
     }
     const ImageOperands operands(offset, lod, ms);
     return Emit(&EmitContext::OpImageSparseFetch, &EmitContext::OpImageFetch, ctx, inst, ctx.F32[4],
-                TextureImage(ctx, index, info), coords, operands.Mask(), operands.Span());
+                TextureImage(ctx, info, index), coords, operands.Mask(), operands.Span());
 }
 
 Id EmitImageQueryDimensions(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id lod) {
     const auto info{inst->Flags<IR::TextureInstInfo>()};
-    const Id image{TextureImage(ctx, index, info)};
+    const Id image{TextureImage(ctx, info, index)};
     const Id zero{ctx.u32_zero_value};
     const auto mips{[&] { return ctx.OpImageQueryLevels(ctx.U32[1], image); }};
     switch (info.type) {
@@ -405,9 +414,10 @@ Id EmitImageQueryDimensions(EmitContext& ctx, IR::Inst* inst, const IR::Value& i
     throw LogicError("Unspecified image type {}", info.type.Value());
 }
 
-Id EmitImageQueryLod(EmitContext& ctx, IR::Inst*, const IR::Value& index, Id coords) {
+Id EmitImageQueryLod(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords) {
+    const auto info{inst->Flags<IR::TextureInstInfo>()};
     const Id zero{ctx.f32_zero_value};
-    const Id sampler{Texture(ctx, index)};
+    const Id sampler{Texture(ctx, info, index)};
     return ctx.OpCompositeConstruct(ctx.F32[4], ctx.OpImageQueryLod(ctx.F32[2], sampler, coords),
                                     zero, zero);
 }
@@ -418,8 +428,8 @@ Id EmitImageGradient(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, I
     const ImageOperands operands(ctx, info.has_lod_clamp != 0, derivates, info.num_derivates,
                                  offset, lod_clamp);
     return Emit(&EmitContext::OpImageSparseSampleExplicitLod,
-                &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4], Texture(ctx, index),
-                coords, operands.Mask(), operands.Span());
+                &EmitContext::OpImageSampleExplicitLod, ctx, inst, ctx.F32[4],
+                Texture(ctx, info, index), coords, operands.Mask(), operands.Span());
 }
 
 Id EmitImageRead(EmitContext& ctx, IR::Inst* inst, const IR::Value& index, Id coords) {
-- 
cgit v1.2.3-70-g09d2