From 9e9aed41bebc1b7d29dbfcddcc203693bcdc680e Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Tue, 19 Dec 2023 10:55:56 -0500
Subject: shader_recompiler: use float image operations on load/store when
 required

---
 .../backend/spirv/spirv_emit_context.cpp           | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

(limited to 'src/shader_recompiler/backend/spirv/spirv_emit_context.cpp')

diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
index 2abc21a173..ed023fcfe0 100644
--- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
@@ -74,20 +74,19 @@ spv::ImageFormat GetImageFormat(ImageFormat format) {
     throw InvalidArgument("Invalid image format {}", format);
 }
 
-Id ImageType(EmitContext& ctx, const ImageDescriptor& desc) {
+Id ImageType(EmitContext& ctx, const ImageDescriptor& desc, Id sampled_type) {
     const spv::ImageFormat format{GetImageFormat(desc.format)};
-    const Id type{ctx.U32[1]};
     switch (desc.type) {
     case TextureType::Color1D:
-        return ctx.TypeImage(type, spv::Dim::Dim1D, false, false, false, 2, format);
+        return ctx.TypeImage(sampled_type, spv::Dim::Dim1D, false, false, false, 2, format);
     case TextureType::ColorArray1D:
-        return ctx.TypeImage(type, spv::Dim::Dim1D, false, true, false, 2, format);
+        return ctx.TypeImage(sampled_type, spv::Dim::Dim1D, false, true, false, 2, format);
     case TextureType::Color2D:
-        return ctx.TypeImage(type, spv::Dim::Dim2D, false, false, false, 2, format);
+        return ctx.TypeImage(sampled_type, spv::Dim::Dim2D, false, false, false, 2, format);
     case TextureType::ColorArray2D:
-        return ctx.TypeImage(type, spv::Dim::Dim2D, false, true, false, 2, format);
+        return ctx.TypeImage(sampled_type, spv::Dim::Dim2D, false, true, false, 2, format);
     case TextureType::Color3D:
-        return ctx.TypeImage(type, spv::Dim::Dim3D, false, false, false, 2, format);
+        return ctx.TypeImage(sampled_type, spv::Dim::Dim3D, false, false, false, 2, format);
     case TextureType::Buffer:
         throw NotImplementedException("Image buffer");
     default:
@@ -1273,7 +1272,9 @@ void EmitContext::DefineImageBuffers(const Info& info, u32& binding) {
             throw NotImplementedException("Array of image buffers");
         }
         const spv::ImageFormat format{GetImageFormat(desc.format)};
-        const Id image_type{TypeImage(U32[1], spv::Dim::Buffer, false, false, false, 2, format)};
+        const Id sampled_type{desc.is_integer ? U32[1] : F32[1]};
+        const Id image_type{
+            TypeImage(sampled_type, spv::Dim::Buffer, false, false, false, 2, format)};
         const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, image_type)};
         const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
@@ -1283,6 +1284,7 @@ void EmitContext::DefineImageBuffers(const Info& info, u32& binding) {
             .id = id,
             .image_type = image_type,
             .count = desc.count,
+            .is_integer = desc.is_integer,
         });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
@@ -1327,7 +1329,8 @@ void EmitContext::DefineImages(const Info& info, u32& binding, u32& scaling_inde
         if (desc.count != 1) {
             throw NotImplementedException("Array of images");
         }
-        const Id image_type{ImageType(*this, desc)};
+        const Id sampled_type{desc.is_integer ? U32[1] : F32[1]};
+        const Id image_type{ImageType(*this, desc, sampled_type)};
         const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, image_type)};
         const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
         Decorate(id, spv::Decoration::Binding, binding);
@@ -1337,6 +1340,7 @@ void EmitContext::DefineImages(const Info& info, u32& binding, u32& scaling_inde
             .id = id,
             .image_type = image_type,
             .count = desc.count,
+            .is_integer = desc.is_integer,
         });
         if (profile.supported_spirv >= 0x00010400) {
             interfaces.push_back(id);
-- 
cgit v1.2.3-70-g09d2