From d8ec99dadaa033aa440671572ed38e2614815e11 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Wed, 14 Apr 2021 18:09:18 -0300
Subject: spirv: Implement Layer stores

---
 src/shader_recompiler/backend/spirv/emit_context.cpp     |  9 ++++++++-
 src/shader_recompiler/backend/spirv/emit_context.h       |  1 +
 src/shader_recompiler/backend/spirv/emit_spirv.cpp       | 16 ++++++++++------
 .../backend/spirv/emit_spirv_context_get_set.cpp         |  9 +++++++--
 4 files changed, 26 insertions(+), 9 deletions(-)

(limited to 'src/shader_recompiler/backend')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 74c42233d7..f96d5ae37d 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -1050,8 +1050,15 @@ void EmitContext::DefineOutputs(const Info& info) {
         const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
         clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance);
     }
+    if (info.stores_layer &&
+        (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
+        if (stage == Stage::Fragment) {
+            throw NotImplementedException("Storing Layer in fragment stage");
+        }
+        layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer);
+    }
     if (info.stores_viewport_index &&
-        (profile.support_viewport_index_layer_non_geometry || stage == Shader::Stage::Geometry)) {
+        (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing ViewportIndex in fragment stage");
         }
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index b27e5540c9..1f0d8be774 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -157,6 +157,7 @@ public:
     Id front_face{};
     Id point_coord{};
     Id clip_distances{};
+    Id layer{};
     Id viewport_index{};
 
     Id fswzadd_lut_a{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 444ba276f7..3bf4c6a9ec 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -124,17 +124,17 @@ void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
     const std::span interfaces(ctx.interfaces.data(), ctx.interfaces.size());
     spv::ExecutionModel execution_model{};
     switch (program.stage) {
-    case Shader::Stage::Compute: {
+    case Stage::Compute: {
         const std::array<u32, 3> workgroup_size{program.workgroup_size};
         execution_model = spv::ExecutionModel::GLCompute;
         ctx.AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
                              workgroup_size[1], workgroup_size[2]);
         break;
     }
-    case Shader::Stage::VertexB:
+    case Stage::VertexB:
         execution_model = spv::ExecutionModel::Vertex;
         break;
-    case Shader::Stage::Geometry:
+    case Stage::Geometry:
         execution_model = spv::ExecutionModel::Geometry;
         ctx.AddCapability(spv::Capability::Geometry);
         ctx.AddCapability(spv::Capability::GeometryStreams);
@@ -172,7 +172,7 @@ void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
         ctx.AddExecutionMode(main, spv::ExecutionMode::OutputVertices, program.output_vertices);
         ctx.AddExecutionMode(main, spv::ExecutionMode::Invocations, program.invocations);
         break;
-    case Shader::Stage::Fragment:
+    case Stage::Fragment:
         execution_model = spv::ExecutionModel::Fragment;
         ctx.AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
         if (program.info.stores_frag_depth) {
@@ -258,10 +258,14 @@ void SetupCapabilities(const Profile& profile, const Info& info, EmitContext& ct
         ctx.AddExtension("SPV_EXT_demote_to_helper_invocation");
         ctx.AddCapability(spv::Capability::DemoteToHelperInvocationEXT);
     }
+    if (info.stores_layer) {
+        ctx.AddCapability(spv::Capability::ShaderLayer);
+    }
     if (info.stores_viewport_index) {
         ctx.AddCapability(spv::Capability::MultiViewport);
-        if (profile.support_viewport_index_layer_non_geometry &&
-            ctx.stage != Shader::Stage::Geometry) {
+    }
+    if (info.stores_layer || info.stores_viewport_index) {
+        if (profile.support_viewport_index_layer_non_geometry && ctx.stage != Stage::Geometry) {
             ctx.AddExtension("SPV_EXT_shader_viewport_index_layer");
             ctx.AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
         }
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index f9c151a5c5..59c56c5ba8 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -76,9 +76,14 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
         const Id clip_num{ctx.Constant(ctx.U32[1], index)};
         return ctx.OpAccessChain(ctx.output_f32, ctx.clip_distances, clip_num);
     }
+    case IR::Attribute::Layer:
+        return ctx.profile.support_viewport_index_layer_non_geometry ||
+                       ctx.stage == Shader::Stage::Geometry
+                   ? std::optional<Id>{ctx.layer}
+                   : std::nullopt;
     case IR::Attribute::ViewportIndex:
-        return (ctx.profile.support_viewport_index_layer_non_geometry ||
-                ctx.stage == Shader::Stage::Geometry)
+        return ctx.profile.support_viewport_index_layer_non_geometry ||
+                       ctx.stage == Shader::Stage::Geometry
                    ? std::optional<Id>{ctx.viewport_index}
                    : std::nullopt;
     default:
-- 
cgit v1.2.3-70-g09d2