From e3514bcd6b09f623da14c4f3c4ffd988e75577ed Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 16 Apr 2021 16:31:15 -0300
Subject: spirv: Implement ViewportMask with NV_viewport_array2

---
 src/shader_recompiler/backend/spirv/emit_context.cpp               | 4 ++++
 src/shader_recompiler/backend/spirv/emit_context.h                 | 2 ++
 src/shader_recompiler/backend/spirv/emit_spirv.cpp                 | 4 ++++
 src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp | 5 +++++
 src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp          | 3 +++
 src/shader_recompiler/profile.h                                    | 1 +
 src/shader_recompiler/shader_info.h                                | 1 +
 7 files changed, 20 insertions(+)

(limited to 'src/shader_recompiler')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 3946dab143..2f8678b4ec 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -457,6 +457,7 @@ void EmitContext::DefineCommonTypes(const Info& info) {
     input_s32 = Name(TypePointer(spv::StorageClass::Input, TypeInt(32, true)), "input_s32");
 
     output_f32 = Name(TypePointer(spv::StorageClass::Output, F32[1]), "output_f32");
+    output_u32 = Name(TypePointer(spv::StorageClass::Output, U32[1]), "output_u32");
 
     if (info.uses_int8) {
         AddCapability(spv::Capability::Int8);
@@ -1131,6 +1132,9 @@ void EmitContext::DefineOutputs(const IR::Program& program) {
         }
         viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
     }
+    if (info.stores_viewport_mask && profile.support_viewport_mask) {
+        viewport_mask = DefineOutput(*this, TypeArray(U32[1], Constant(U32[1], 1u)), std::nullopt);
+    }
     for (size_t index = 0; index < info.stores_generics.size(); ++index) {
         if (info.stores_generics[index]) {
             DefineGenericOutput(*this, index, invocations);
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index c7d6f8a38f..c41cad098b 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -134,6 +134,7 @@ public:
     Id input_s32{};
 
     Id output_f32{};
+    Id output_u32{};
 
     Id image_buffer_type{};
     Id sampled_texture_buffer_type{};
@@ -167,6 +168,7 @@ public:
     Id clip_distances{};
     Id layer{};
     Id viewport_index{};
+    Id viewport_mask{};
     Id primitive_id{};
 
     Id fswzadd_lut_a{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 105602ccf5..90c4833a88 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -303,6 +303,10 @@ void SetupCapabilities(const Profile& profile, const Info& info, EmitContext& ct
     if (info.stores_viewport_index) {
         ctx.AddCapability(spv::Capability::MultiViewport);
     }
+    if (info.stores_viewport_mask && profile.support_viewport_mask) {
+        ctx.AddExtension("SPV_NV_viewport_array2");
+        ctx.AddCapability(spv::Capability::ShaderViewportMaskNV);
+    }
     if (info.stores_layer || info.stores_viewport_index) {
         if (profile.support_viewport_index_layer_non_geometry && ctx.stage != Stage::Geometry) {
             ctx.AddExtension("SPV_EXT_shader_viewport_index_layer");
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index f3de577f6c..ca067f1c43 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -99,6 +99,11 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
                        ctx.stage == Shader::Stage::Geometry
                    ? std::optional<Id>{ctx.viewport_index}
                    : std::nullopt;
+    case IR::Attribute::ViewportMask:
+        if (!ctx.profile.support_viewport_mask) {
+            return std::nullopt;
+        }
+        return ctx.OpAccessChain(ctx.output_u32, ctx.viewport_mask, ctx.u32_zero_value);
     default:
         throw NotImplementedException("Read attribute {}", attr);
     }
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index c84bf211fb..9631a445ee 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -96,6 +96,9 @@ void SetAttribute(Info& info, IR::Attribute attribute) {
     case IR::Attribute::ViewportIndex:
         info.stores_viewport_index = true;
         break;
+    case IR::Attribute::ViewportMask:
+        info.stores_viewport_mask = true;
+        break;
     default:
         throw NotImplementedException("Set attribute {}", attribute);
     }
diff --git a/src/shader_recompiler/profile.h b/src/shader_recompiler/profile.h
index 3a04f075ee..a2c2948d50 100644
--- a/src/shader_recompiler/profile.h
+++ b/src/shader_recompiler/profile.h
@@ -75,6 +75,7 @@ struct Profile {
     bool support_explicit_workgroup_layout{};
     bool support_vote{};
     bool support_viewport_index_layer_non_geometry{};
+    bool support_viewport_mask{};
     bool support_typeless_image_loads{};
     bool warp_size_potentially_larger_than_guest{};
     bool support_int64_atomics{};
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index d6cde15960..d33df8aad4 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -124,6 +124,7 @@ struct Info {
     bool stores_clip_distance{};
     bool stores_layer{};
     bool stores_viewport_index{};
+    bool stores_viewport_mask{};
     bool stores_tess_level_outer{};
     bool stores_tess_level_inner{};
     bool stores_indexed_attributes{};
-- 
cgit v1.2.3-70-g09d2