From 14ac0c2923c41df9c6fc4833d2a8e46a6efe5b59 Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Fri, 24 Dec 2021 20:00:28 -0500
Subject: shader: Add integer attribute get optimization pass

Works around an nvidia driver bug, where casting the integer attributes to float and back to an integer always returned 0.
---
 .../backend/glasm/emit_glasm_context_get_set.cpp   | 16 ++++++++++++++
 .../backend/glasm/emit_glasm_instructions.h        |  1 +
 .../backend/glsl/emit_glsl_context_get_set.cpp     | 16 ++++++++++++++
 .../backend/glsl/emit_glsl_instructions.h          |  2 ++
 .../backend/spirv/emit_spirv_context_get_set.cpp   | 25 ++++++++++++++++++++++
 .../backend/spirv/emit_spirv_instructions.h        |  1 +
 src/shader_recompiler/frontend/ir/opcodes.inc      |  1 +
 .../ir_opt/collect_shader_info_pass.cpp            |  1 +
 .../ir_opt/constant_propagation_pass.cpp           | 23 ++++++++++++++++++++
 9 files changed, 86 insertions(+)

(limited to 'src')

diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_context_get_set.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_context_get_set.cpp
index 081b2c8e03..c0f5fc4024 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_context_get_set.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_context_get_set.cpp
@@ -126,6 +126,22 @@ void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr, Scal
     }
 }
 
+void EmitGetAttributeU32(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr, ScalarU32) {
+    switch (attr) {
+    case IR::Attribute::PrimitiveId:
+        ctx.Add("MOV.S {}.x,primitive.id;", inst);
+        break;
+    case IR::Attribute::InstanceId:
+        ctx.Add("MOV.S {}.x,{}.instance;", inst, ctx.attrib_name);
+        break;
+    case IR::Attribute::VertexId:
+        ctx.Add("MOV.S {}.x,{}.id;", inst, ctx.attrib_name);
+        break;
+    default:
+        throw NotImplementedException("Get U32 attribute {}", attr);
+    }
+}
+
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, ScalarF32 value,
                       [[maybe_unused]] ScalarU32 vertex) {
     const u32 element{static_cast<u32>(attr) % 4};
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
index 1f343bff5e..b480078563 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
@@ -50,6 +50,7 @@ void EmitGetCbufU32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
 void EmitGetCbufF32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding, ScalarU32 offset);
 void EmitGetCbufU32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding, ScalarU32 offset);
 void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr, ScalarU32 vertex);
+void EmitGetAttributeU32(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr, ScalarU32 vertex);
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, ScalarF32 value, ScalarU32 vertex);
 void EmitGetAttributeIndexed(EmitContext& ctx, IR::Inst& inst, ScalarS32 offset, ScalarU32 vertex);
 void EmitSetAttributeIndexed(EmitContext& ctx, ScalarU32 offset, ScalarF32 value, ScalarU32 vertex);
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp b/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
index 6477bd1928..5ef46d6343 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
@@ -221,6 +221,22 @@ void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
     }
 }
 
+void EmitGetAttributeU32(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr, std::string_view) {
+    switch (attr) {
+    case IR::Attribute::PrimitiveId:
+        ctx.AddU32("{}=uint(gl_PrimitiveID);", inst);
+        break;
+    case IR::Attribute::InstanceId:
+        ctx.AddU32("{}=uint(gl_InstanceID);", inst);
+        break;
+    case IR::Attribute::VertexId:
+        ctx.AddU32("{}=uint(gl_VertexID);", inst);
+        break;
+    default:
+        throw NotImplementedException("Get U32 attribute {}", attr);
+    }
+}
+
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, std::string_view value,
                       [[maybe_unused]] std::string_view vertex) {
     if (IR::IsGeneric(attr)) {
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h b/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
index f86502e4c8..6cabbc717d 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
@@ -60,6 +60,8 @@ void EmitGetCbufU32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding
                       const IR::Value& offset);
 void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
                       std::string_view vertex);
+void EmitGetAttributeU32(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
+                         std::string_view vertex);
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, std::string_view value,
                       std::string_view vertex);
 void EmitGetAttributeIndexed(EmitContext& ctx, IR::Inst& inst, std::string_view offset,
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 14f470812e..8ea730c807 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -355,6 +355,31 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex) {
     }
 }
 
+Id EmitGetAttributeU32(EmitContext& ctx, IR::Attribute attr, Id) {
+    switch (attr) {
+    case IR::Attribute::PrimitiveId:
+        return ctx.OpLoad(ctx.U32[1], ctx.primitive_id);
+    case IR::Attribute::InstanceId:
+        if (ctx.profile.support_vertex_instance_id) {
+            return ctx.OpLoad(ctx.U32[1], ctx.instance_id);
+        } else {
+            const Id index{ctx.OpLoad(ctx.U32[1], ctx.instance_index)};
+            const Id base{ctx.OpLoad(ctx.U32[1], ctx.base_instance)};
+            return ctx.OpISub(ctx.U32[1], index, base);
+        }
+    case IR::Attribute::VertexId:
+        if (ctx.profile.support_vertex_instance_id) {
+            return ctx.OpLoad(ctx.U32[1], ctx.vertex_id);
+        } else {
+            const Id index{ctx.OpLoad(ctx.U32[1], ctx.vertex_index)};
+            const Id base{ctx.OpLoad(ctx.U32[1], ctx.base_vertex)};
+            return ctx.OpISub(ctx.U32[1], index, base);
+        }
+    default:
+        throw NotImplementedException("Read U32 attribute {}", attr);
+    }
+}
+
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, [[maybe_unused]] Id vertex) {
     const std::optional<OutAttr> output{OutputAttrPointer(ctx, attr)};
     if (!output) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
index 6cd22dd3ef..887112deb4 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
@@ -53,6 +53,7 @@ Id EmitGetCbufU32(EmitContext& ctx, const IR::Value& binding, const IR::Value& o
 Id EmitGetCbufF32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
 Id EmitGetCbufU32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
 Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex);
+Id EmitGetAttributeU32(EmitContext& ctx, IR::Attribute attr, Id vertex);
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, Id vertex);
 Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex);
 void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, Id vertex);
diff --git a/src/shader_recompiler/frontend/ir/opcodes.inc b/src/shader_recompiler/frontend/ir/opcodes.inc
index 6929919df1..b94ce74061 100644
--- a/src/shader_recompiler/frontend/ir/opcodes.inc
+++ b/src/shader_recompiler/frontend/ir/opcodes.inc
@@ -40,6 +40,7 @@ OPCODE(GetCbufU32,                                          U32,            U32,
 OPCODE(GetCbufF32,                                          F32,            U32,            U32,                                                            )
 OPCODE(GetCbufU32x2,                                        U32x2,          U32,            U32,                                                            )
 OPCODE(GetAttribute,                                        F32,            Attribute,      U32,                                                            )
+OPCODE(GetAttributeU32,                                     U32,            Attribute,      U32,                                                            )
 OPCODE(SetAttribute,                                        Void,           Attribute,      F32,            U32,                                            )
 OPCODE(GetAttributeIndexed,                                 F32,            U32,            U32,                                                            )
 OPCODE(SetAttributeIndexed,                                 Void,           U32,            F32,            U32,                                            )
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index 1e476d83d9..a78c469be1 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -389,6 +389,7 @@ void VisitUsages(Info& info, IR::Inst& inst) {
         info.uses_demote_to_helper_invocation = true;
         break;
     case IR::Opcode::GetAttribute:
+    case IR::Opcode::GetAttributeU32:
         info.loads.mask[static_cast<size_t>(inst.Arg(0).Attribute())] = true;
         break;
     case IR::Opcode::SetAttribute:
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index d089fdd12f..c134a12bc6 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -505,6 +505,29 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
             return;
         }
     }
+    if constexpr (op == IR::Opcode::BitCastU32F32) {
+        // Workaround for new NVIDIA driver bug, where:
+        // uint attr = ftou(itof(gl_InstanceID));
+        // always returned 0.
+        // We can instead manually optimize this and work around the driver bug:
+        // uint attr = uint(gl_InstanceID);
+        if (arg_inst->GetOpcode() == IR::Opcode::GetAttribute) {
+            const IR::Attribute attr{arg_inst->Arg(0).Attribute()};
+            switch (attr) {
+            case IR::Attribute::PrimitiveId:
+            case IR::Attribute::InstanceId:
+            case IR::Attribute::VertexId:
+                break;
+            default:
+                return;
+            }
+            // Replace the bitcasts with an integer attribute get
+            inst.ReplaceOpcode(IR::Opcode::GetAttributeU32);
+            inst.SetArg(0, arg_inst->Arg(0));
+            inst.SetArg(1, arg_inst->Arg(1));
+            return;
+        }
+    }
 }
 
 void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
-- 
cgit v1.2.3-70-g09d2