From bbfad79c89f9b7886005d39b51129bcfd94830b8 Mon Sep 17 00:00:00 2001
From: Billy Laws <blaws05@gmail.com>
Date: Tue, 2 Aug 2022 17:41:41 +0100
Subject: Vulkan: Add a workaround for input_position on Adreno drivers

Adreno drivers will crash compiling geometry shaders if the input position is not wrapped in a gl_in struct.
---
 .../backend/spirv/emit_spirv_context_get_set.cpp   |  7 ++--
 .../backend/spirv/spirv_emit_context.cpp           | 42 +++++++++++++++++-----
 .../backend/spirv/spirv_emit_context.h             |  1 +
 3 files changed, 39 insertions(+), 11 deletions(-)

(limited to 'src/shader_recompiler/backend/spirv')

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index db9c94ce8e..1590debc4c 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -321,8 +321,11 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex) {
     case IR::Attribute::PositionY:
     case IR::Attribute::PositionZ:
     case IR::Attribute::PositionW:
-        return ctx.OpLoad(ctx.F32[1], AttrPointer(ctx, ctx.input_f32, vertex, ctx.input_position,
-                                                  ctx.Const(element)));
+        return ctx.OpLoad(ctx.F32[1], ctx.need_input_position_indirect ?
+                                      AttrPointer(ctx, ctx.input_f32, vertex, ctx.input_position,
+                                                  ctx.u32_zero_value, ctx.Const(element))
+                                      : AttrPointer(ctx, ctx.input_f32, vertex, ctx.input_position,
+                                                    ctx.Const(element)));
     case IR::Attribute::InstanceId:
         if (ctx.profile.support_vertex_instance_id) {
             return ctx.OpBitcast(ctx.F32[1], ctx.OpLoad(ctx.U32[1], ctx.instance_id));
diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
index ecb2db4940..eb1e3e0b6e 100644
--- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
@@ -721,9 +721,21 @@ void EmitContext::DefineAttributeMemAccess(const Info& info) {
         size_t label_index{0};
         if (info.loads.AnyComponent(IR::Attribute::PositionX)) {
             AddLabel(labels[label_index]);
-            const Id pointer{is_array
-                                 ? OpAccessChain(input_f32, input_position, vertex, masked_index)
-                                 : OpAccessChain(input_f32, input_position, masked_index)};
+            const Id pointer{[&]() {
+                if (need_input_position_indirect) {
+                    if (is_array)
+                        return OpAccessChain(input_f32, input_position, vertex, u32_zero_value,
+                                             masked_index);
+                    else
+                        return OpAccessChain(input_f32, input_position, u32_zero_value,
+                                             masked_index);
+                }  else {
+                    if (is_array)
+                        return OpAccessChain(input_f32, input_position, vertex, masked_index);
+                    else
+                        return OpAccessChain(input_f32, input_position, masked_index);
+                }
+            }()};
             const Id result{OpLoad(F32[1], pointer)};
             OpReturnValue(result);
             ++label_index;
@@ -1367,12 +1379,24 @@ void EmitContext::DefineInputs(const IR::Program& program) {
         Decorate(layer, spv::Decoration::Flat);
     }
     if (loads.AnyComponent(IR::Attribute::PositionX)) {
-        const bool is_fragment{stage != Stage::Fragment};
-        const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
-        input_position = DefineInput(*this, F32[4], true, built_in);
-        if (profile.support_geometry_shader_passthrough) {
-            if (info.passthrough.AnyComponent(IR::Attribute::PositionX)) {
-                Decorate(input_position, spv::Decoration::PassthroughNV);
+        const bool is_fragment{stage == Stage::Fragment};
+        if (!is_fragment && profile.has_broken_spirv_position_input) {
+            need_input_position_indirect = true;
+
+            const Id input_position_struct = TypeStruct(F32[4]);
+            input_position = DefineInput(*this, input_position_struct, true);
+
+            MemberDecorate(input_position_struct, 0, spv::Decoration::BuiltIn,
+                           static_cast<unsigned>(spv::BuiltIn::Position));
+            Decorate(input_position_struct, spv::Decoration::Block);
+        } else {
+            const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::FragCoord : spv::BuiltIn::Position};
+            input_position = DefineInput(*this, F32[4], true, built_in);
+
+            if (profile.support_geometry_shader_passthrough) {
+                if (info.passthrough.AnyComponent(IR::Attribute::PositionX)) {
+                    Decorate(input_position, spv::Decoration::PassthroughNV);
+                }
             }
         }
     }
diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.h b/src/shader_recompiler/backend/spirv/spirv_emit_context.h
index 4414a51696..dbc5c55b90 100644
--- a/src/shader_recompiler/backend/spirv/spirv_emit_context.h
+++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.h
@@ -280,6 +280,7 @@ public:
     Id write_global_func_u32x2{};
     Id write_global_func_u32x4{};
 
+    bool need_input_position_indirect{};
     Id input_position{};
     std::array<Id, 32> input_generics{};
 
-- 
cgit v1.2.3-70-g09d2