From 183855e396cc6918d36fbf3e38ea426e934b4e3e Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Thu, 15 Apr 2021 22:46:11 -0300
Subject: shader: Implement tessellation shaders, polygon mode and invocation
 id

---
 .../backend/spirv/emit_context.cpp                 | 147 +++++++++++++++------
 1 file changed, 105 insertions(+), 42 deletions(-)

(limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 032cf5e03e..067f616137 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -125,19 +125,36 @@ u32 NumVertices(InputTopology input_topology) {
     throw InvalidArgument("Invalid input topology {}", input_topology);
 }
 
-Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
-    if (ctx.stage == Stage::Geometry) {
-        const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
-        type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
+               std::optional<spv::BuiltIn> builtin = std::nullopt) {
+    switch (ctx.stage) {
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+        if (per_invocation) {
+            type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 32u));
+        }
+        break;
+    case Stage::Geometry:
+        if (per_invocation) {
+            const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
+            type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+        }
+        break;
+    default:
+        break;
     }
     return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
 }
 
-Id DefineOutput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
+Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
+                std::optional<spv::BuiltIn> builtin = std::nullopt) {
+    if (invocations && ctx.stage == Stage::TessellationControl) {
+        type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], *invocations));
+    }
     return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
 }
 
-void DefineGenericOutput(EmitContext& ctx, size_t index) {
+void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
     static constexpr std::string_view swizzle{"xyzw"};
     const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
     u32 element{0};
@@ -150,7 +167,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
         }
         const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
 
-        const Id id{DefineOutput(ctx, ctx.F32[num_components])};
+        const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
         ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
         if (element > 0) {
             ctx.Decorate(id, spv::Decoration::Component, element);
@@ -161,10 +178,10 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
             ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
         }
         if (num_components < 4 || element > 0) {
-            ctx.Name(id, fmt::format("out_attr{}", index));
-        } else {
             const std::string_view subswizzle{swizzle.substr(element, num_components)};
             ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
+        } else {
+            ctx.Name(id, fmt::format("out_attr{}", index));
         }
         const GenericElementInfo info{
             .id = id,
@@ -383,7 +400,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
     AddCapability(spv::Capability::Shader);
     DefineCommonTypes(program.info);
     DefineCommonConstants();
-    DefineInterfaces(program.info);
+    DefineInterfaces(program);
     DefineLocalMemory(program);
     DefineSharedMemory(program);
     DefineSharedMemoryFunctions(program);
@@ -472,9 +489,9 @@ void EmitContext::DefineCommonConstants() {
     f32_zero_value = Constant(F32[1], 0.0f);
 }
 
-void EmitContext::DefineInterfaces(const Info& info) {
-    DefineInputs(info);
-    DefineOutputs(info);
+void EmitContext::DefineInterfaces(const IR::Program& program) {
+    DefineInputs(program.info);
+    DefineOutputs(program);
 }
 
 void EmitContext::DefineLocalMemory(const IR::Program& program) {
@@ -972,26 +989,29 @@ void EmitContext::DefineLabels(IR::Program& program) {
 
 void EmitContext::DefineInputs(const Info& info) {
     if (info.uses_workgroup_id) {
-        workgroup_id = DefineInput(*this, U32[3], spv::BuiltIn::WorkgroupId);
+        workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
     }
     if (info.uses_local_invocation_id) {
-        local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
+        local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
+    }
+    if (info.uses_invocation_id) {
+        invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
     }
     if (info.uses_is_helper_invocation) {
-        is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation);
+        is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
     }
     if (info.uses_subgroup_mask) {
-        subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
-        subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
-        subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
-        subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
-        subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
+        subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
+        subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
+        subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
+        subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
+        subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
     }
     if (info.uses_subgroup_invocation_id ||
         (profile.warp_size_potentially_larger_than_guest &&
          (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
         subgroup_local_invocation_id =
-            DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
+            DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
     }
     if (info.uses_fswzadd) {
         const Id f32_one{Constant(F32[1], 1.0f)};
@@ -1004,29 +1024,32 @@ void EmitContext::DefineInputs(const Info& info) {
     if (info.loads_position) {
         const bool is_fragment{stage != Stage::Fragment};
         const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
-        input_position = DefineInput(*this, F32[4], built_in);
+        input_position = DefineInput(*this, F32[4], true, built_in);
     }
     if (info.loads_instance_id) {
         if (profile.support_vertex_instance_id) {
-            instance_id = DefineInput(*this, U32[1], spv::BuiltIn::InstanceId);
+            instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
         } else {
-            instance_index = DefineInput(*this, U32[1], spv::BuiltIn::InstanceIndex);
-            base_instance = DefineInput(*this, U32[1], spv::BuiltIn::BaseInstance);
+            instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
+            base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
         }
     }
     if (info.loads_vertex_id) {
         if (profile.support_vertex_instance_id) {
-            vertex_id = DefineInput(*this, U32[1], spv::BuiltIn::VertexId);
+            vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
         } else {
-            vertex_index = DefineInput(*this, U32[1], spv::BuiltIn::VertexIndex);
-            base_vertex = DefineInput(*this, U32[1], spv::BuiltIn::BaseVertex);
+            vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
+            base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
         }
     }
     if (info.loads_front_face) {
-        front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing);
+        front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
     }
     if (info.loads_point_coord) {
-        point_coord = DefineInput(*this, F32[2], spv::BuiltIn::PointCoord);
+        point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
+    }
+    if (info.loads_tess_coord) {
+        tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
     }
     for (size_t index = 0; index < info.input_generics.size(); ++index) {
         const InputVarying generic{info.input_generics[index]};
@@ -1038,7 +1061,7 @@ void EmitContext::DefineInputs(const Info& info) {
             continue;
         }
         const Id type{GetAttributeType(*this, input_type)};
-        const Id id{DefineInput(*this, type)};
+        const Id id{DefineInput(*this, type, true)};
         Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
         Name(id, fmt::format("in_attr{}", index));
         input_generics[index] = id;
@@ -1059,58 +1082,98 @@ void EmitContext::DefineInputs(const Info& info) {
             break;
         }
     }
+    if (stage == Stage::TessellationEval) {
+        for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+            if (!info.uses_patches[index]) {
+                continue;
+            }
+            const Id id{DefineInput(*this, F32[4], false)};
+            Decorate(id, spv::Decoration::Patch);
+            Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+            patches[index] = id;
+        }
+    }
 }
 
-void EmitContext::DefineOutputs(const Info& info) {
+void EmitContext::DefineOutputs(const IR::Program& program) {
+    const Info& info{program.info};
+    const std::optional<u32> invocations{program.invocations};
     if (info.stores_position || stage == Stage::VertexB) {
-        output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position);
+        output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
     }
     if (info.stores_point_size || profile.fixed_state_point_size) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing PointSize in fragment stage");
         }
-        output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize);
+        output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
     }
     if (info.stores_clip_distance) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing ClipDistance in fragment stage");
         }
         const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
-        clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance);
+        clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
     }
     if (info.stores_layer &&
         (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing Layer in fragment stage");
         }
-        layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer);
+        layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
     }
     if (info.stores_viewport_index &&
         (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing ViewportIndex in fragment stage");
         }
-        viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex);
+        viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
     }
     for (size_t index = 0; index < info.stores_generics.size(); ++index) {
         if (info.stores_generics[index]) {
-            DefineGenericOutput(*this, index);
+            DefineGenericOutput(*this, index, invocations);
         }
     }
-    if (stage == Stage::Fragment) {
+    switch (stage) {
+    case Stage::TessellationControl:
+        if (info.stores_tess_level_outer) {
+            const Id type{TypeArray(F32[1], Constant(U32[1], 4))};
+            output_tess_level_outer =
+                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
+            Decorate(output_tess_level_outer, spv::Decoration::Patch);
+        }
+        if (info.stores_tess_level_inner) {
+            const Id type{TypeArray(F32[1], Constant(U32[1], 2))};
+            output_tess_level_inner =
+                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
+            Decorate(output_tess_level_inner, spv::Decoration::Patch);
+        }
+        for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+            if (!info.uses_patches[index]) {
+                continue;
+            }
+            const Id id{DefineOutput(*this, F32[4], std::nullopt)};
+            Decorate(id, spv::Decoration::Patch);
+            Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+            patches[index] = id;
+        }
+        break;
+    case Stage::Fragment:
         for (u32 index = 0; index < 8; ++index) {
             if (!info.stores_frag_color[index]) {
                 continue;
             }
-            frag_color[index] = DefineOutput(*this, F32[4]);
+            frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
             Decorate(frag_color[index], spv::Decoration::Location, index);
             Name(frag_color[index], fmt::format("frag_color{}", index));
         }
         if (info.stores_frag_depth) {
-            frag_depth = DefineOutput(*this, F32[1]);
+            frag_depth = DefineOutput(*this, F32[1], std::nullopt);
             Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
             Name(frag_depth, "frag_depth");
         }
+        break;
+    default:
+        break;
     }
 }
 
-- 
cgit v1.2.3-70-g09d2