From 183855e396cc6918d36fbf3e38ea426e934b4e3e Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Thu, 15 Apr 2021 22:46:11 -0300
Subject: shader: Implement tessellation shaders, polygon mode and invocation
 id

---
 .../backend/spirv/emit_context.cpp                 | 147 +++++++++++++++------
 src/shader_recompiler/backend/spirv/emit_context.h |  10 +-
 src/shader_recompiler/backend/spirv/emit_spirv.cpp |  39 ++++++
 src/shader_recompiler/backend/spirv/emit_spirv.h   |   3 +
 .../backend/spirv/emit_spirv_context_get_set.cpp   |  88 ++++++++++--
 5 files changed, 232 insertions(+), 55 deletions(-)

(limited to 'src/shader_recompiler/backend/spirv')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 032cf5e03e..067f616137 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -125,19 +125,36 @@ u32 NumVertices(InputTopology input_topology) {
     throw InvalidArgument("Invalid input topology {}", input_topology);
 }
 
-Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
-    if (ctx.stage == Stage::Geometry) {
-        const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
-        type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
+               std::optional<spv::BuiltIn> builtin = std::nullopt) {
+    switch (ctx.stage) {
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+        if (per_invocation) {
+            type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 32u));
+        }
+        break;
+    case Stage::Geometry:
+        if (per_invocation) {
+            const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
+            type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
+        }
+        break;
+    default:
+        break;
     }
     return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
 }
 
-Id DefineOutput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) {
+Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
+                std::optional<spv::BuiltIn> builtin = std::nullopt) {
+    if (invocations && ctx.stage == Stage::TessellationControl) {
+        type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], *invocations));
+    }
     return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
 }
 
-void DefineGenericOutput(EmitContext& ctx, size_t index) {
+void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
     static constexpr std::string_view swizzle{"xyzw"};
     const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
     u32 element{0};
@@ -150,7 +167,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
         }
         const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
 
-        const Id id{DefineOutput(ctx, ctx.F32[num_components])};
+        const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
         ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
         if (element > 0) {
             ctx.Decorate(id, spv::Decoration::Component, element);
@@ -161,10 +178,10 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
             ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
         }
         if (num_components < 4 || element > 0) {
-            ctx.Name(id, fmt::format("out_attr{}", index));
-        } else {
             const std::string_view subswizzle{swizzle.substr(element, num_components)};
             ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
+        } else {
+            ctx.Name(id, fmt::format("out_attr{}", index));
         }
         const GenericElementInfo info{
             .id = id,
@@ -383,7 +400,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
     AddCapability(spv::Capability::Shader);
     DefineCommonTypes(program.info);
     DefineCommonConstants();
-    DefineInterfaces(program.info);
+    DefineInterfaces(program);
     DefineLocalMemory(program);
     DefineSharedMemory(program);
     DefineSharedMemoryFunctions(program);
@@ -472,9 +489,9 @@ void EmitContext::DefineCommonConstants() {
     f32_zero_value = Constant(F32[1], 0.0f);
 }
 
-void EmitContext::DefineInterfaces(const Info& info) {
-    DefineInputs(info);
-    DefineOutputs(info);
+void EmitContext::DefineInterfaces(const IR::Program& program) {
+    DefineInputs(program.info);
+    DefineOutputs(program);
 }
 
 void EmitContext::DefineLocalMemory(const IR::Program& program) {
@@ -972,26 +989,29 @@ void EmitContext::DefineLabels(IR::Program& program) {
 
 void EmitContext::DefineInputs(const Info& info) {
     if (info.uses_workgroup_id) {
-        workgroup_id = DefineInput(*this, U32[3], spv::BuiltIn::WorkgroupId);
+        workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
     }
     if (info.uses_local_invocation_id) {
-        local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
+        local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
+    }
+    if (info.uses_invocation_id) {
+        invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
     }
     if (info.uses_is_helper_invocation) {
-        is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation);
+        is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
     }
     if (info.uses_subgroup_mask) {
-        subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
-        subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
-        subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
-        subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
-        subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
+        subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
+        subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
+        subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
+        subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
+        subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
     }
     if (info.uses_subgroup_invocation_id ||
         (profile.warp_size_potentially_larger_than_guest &&
          (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
         subgroup_local_invocation_id =
-            DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
+            DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
     }
     if (info.uses_fswzadd) {
         const Id f32_one{Constant(F32[1], 1.0f)};
@@ -1004,29 +1024,32 @@ void EmitContext::DefineInputs(const Info& info) {
     if (info.loads_position) {
         const bool is_fragment{stage != Stage::Fragment};
         const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
-        input_position = DefineInput(*this, F32[4], built_in);
+        input_position = DefineInput(*this, F32[4], true, built_in);
     }
     if (info.loads_instance_id) {
         if (profile.support_vertex_instance_id) {
-            instance_id = DefineInput(*this, U32[1], spv::BuiltIn::InstanceId);
+            instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
         } else {
-            instance_index = DefineInput(*this, U32[1], spv::BuiltIn::InstanceIndex);
-            base_instance = DefineInput(*this, U32[1], spv::BuiltIn::BaseInstance);
+            instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
+            base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
         }
     }
     if (info.loads_vertex_id) {
         if (profile.support_vertex_instance_id) {
-            vertex_id = DefineInput(*this, U32[1], spv::BuiltIn::VertexId);
+            vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
         } else {
-            vertex_index = DefineInput(*this, U32[1], spv::BuiltIn::VertexIndex);
-            base_vertex = DefineInput(*this, U32[1], spv::BuiltIn::BaseVertex);
+            vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
+            base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
         }
     }
     if (info.loads_front_face) {
-        front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing);
+        front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
     }
     if (info.loads_point_coord) {
-        point_coord = DefineInput(*this, F32[2], spv::BuiltIn::PointCoord);
+        point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
+    }
+    if (info.loads_tess_coord) {
+        tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
     }
     for (size_t index = 0; index < info.input_generics.size(); ++index) {
         const InputVarying generic{info.input_generics[index]};
@@ -1038,7 +1061,7 @@ void EmitContext::DefineInputs(const Info& info) {
             continue;
         }
         const Id type{GetAttributeType(*this, input_type)};
-        const Id id{DefineInput(*this, type)};
+        const Id id{DefineInput(*this, type, true)};
         Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
         Name(id, fmt::format("in_attr{}", index));
         input_generics[index] = id;
@@ -1059,58 +1082,98 @@ void EmitContext::DefineInputs(const Info& info) {
             break;
         }
     }
+    if (stage == Stage::TessellationEval) {
+        for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+            if (!info.uses_patches[index]) {
+                continue;
+            }
+            const Id id{DefineInput(*this, F32[4], false)};
+            Decorate(id, spv::Decoration::Patch);
+            Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+            patches[index] = id;
+        }
+    }
 }
 
-void EmitContext::DefineOutputs(const Info& info) {
+void EmitContext::DefineOutputs(const IR::Program& program) {
+    const Info& info{program.info};
+    const std::optional<u32> invocations{program.invocations};
     if (info.stores_position || stage == Stage::VertexB) {
-        output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position);
+        output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
     }
     if (info.stores_point_size || profile.fixed_state_point_size) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing PointSize in fragment stage");
         }
-        output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize);
+        output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
     }
     if (info.stores_clip_distance) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing ClipDistance in fragment stage");
         }
         const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
-        clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance);
+        clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
     }
     if (info.stores_layer &&
         (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing Layer in fragment stage");
         }
-        layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer);
+        layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
     }
     if (info.stores_viewport_index &&
         (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
         if (stage == Stage::Fragment) {
             throw NotImplementedException("Storing ViewportIndex in fragment stage");
         }
-        viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex);
+        viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
     }
     for (size_t index = 0; index < info.stores_generics.size(); ++index) {
         if (info.stores_generics[index]) {
-            DefineGenericOutput(*this, index);
+            DefineGenericOutput(*this, index, invocations);
         }
     }
-    if (stage == Stage::Fragment) {
+    switch (stage) {
+    case Stage::TessellationControl:
+        if (info.stores_tess_level_outer) {
+            const Id type{TypeArray(F32[1], Constant(U32[1], 4))};
+            output_tess_level_outer =
+                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
+            Decorate(output_tess_level_outer, spv::Decoration::Patch);
+        }
+        if (info.stores_tess_level_inner) {
+            const Id type{TypeArray(F32[1], Constant(U32[1], 2))};
+            output_tess_level_inner =
+                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
+            Decorate(output_tess_level_inner, spv::Decoration::Patch);
+        }
+        for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+            if (!info.uses_patches[index]) {
+                continue;
+            }
+            const Id id{DefineOutput(*this, F32[4], std::nullopt)};
+            Decorate(id, spv::Decoration::Patch);
+            Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
+            patches[index] = id;
+        }
+        break;
+    case Stage::Fragment:
         for (u32 index = 0; index < 8; ++index) {
             if (!info.stores_frag_color[index]) {
                 continue;
             }
-            frag_color[index] = DefineOutput(*this, F32[4]);
+            frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
             Decorate(frag_color[index], spv::Decoration::Location, index);
             Name(frag_color[index], fmt::format("frag_color{}", index));
         }
         if (info.stores_frag_depth) {
-            frag_depth = DefineOutput(*this, F32[1]);
+            frag_depth = DefineOutput(*this, F32[1], std::nullopt);
             Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
             Name(frag_depth, "frag_depth");
         }
+        break;
+    default:
+        break;
     }
 }
 
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 0da14d5f8e..ba0a253b35 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -147,6 +147,7 @@ public:
 
     Id workgroup_id{};
     Id local_invocation_id{};
+    Id invocation_id{};
     Id is_helper_invocation{};
     Id subgroup_local_invocation_id{};
     Id subgroup_mask_eq{};
@@ -162,6 +163,7 @@ public:
     Id base_vertex{};
     Id front_face{};
     Id point_coord{};
+    Id tess_coord{};
     Id clip_distances{};
     Id layer{};
     Id viewport_index{};
@@ -204,6 +206,10 @@ public:
     Id output_position{};
     std::array<std::array<GenericElementInfo, 4>, 32> output_generics{};
 
+    Id output_tess_level_outer{};
+    Id output_tess_level_inner{};
+    std::array<Id, 30> patches{};
+
     std::array<Id, 8> frag_color{};
     Id frag_depth{};
 
@@ -212,7 +218,7 @@ public:
 private:
     void DefineCommonTypes(const Info& info);
     void DefineCommonConstants();
-    void DefineInterfaces(const Info& info);
+    void DefineInterfaces(const IR::Program& program);
     void DefineLocalMemory(const IR::Program& program);
     void DefineSharedMemory(const IR::Program& program);
     void DefineSharedMemoryFunctions(const IR::Program& program);
@@ -226,7 +232,7 @@ private:
     void DefineLabels(IR::Program& program);
 
     void DefineInputs(const Info& info);
-    void DefineOutputs(const Info& info);
+    void DefineOutputs(const IR::Program& program);
 };
 
 } // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 3bf4c6a9ec..105602ccf5 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -45,6 +45,8 @@ ArgType Arg(EmitContext& ctx, const IR::Value& arg) {
         return arg.Label();
     } else if constexpr (std::is_same_v<ArgType, IR::Attribute>) {
         return arg.Attribute();
+    } else if constexpr (std::is_same_v<ArgType, IR::Patch>) {
+        return arg.Patch();
     } else if constexpr (std::is_same_v<ArgType, IR::Reg>) {
         return arg.Reg();
     }
@@ -120,6 +122,30 @@ Id DefineMain(EmitContext& ctx, IR::Program& program) {
     return main;
 }
 
+spv::ExecutionMode ExecutionMode(TessPrimitive primitive) {
+    switch (primitive) {
+    case TessPrimitive::Isolines:
+        return spv::ExecutionMode::Isolines;
+    case TessPrimitive::Triangles:
+        return spv::ExecutionMode::Triangles;
+    case TessPrimitive::Quads:
+        return spv::ExecutionMode::Quads;
+    }
+    throw InvalidArgument("Tessellation primitive {}", primitive);
+}
+
+spv::ExecutionMode ExecutionMode(TessSpacing spacing) {
+    switch (spacing) {
+    case TessSpacing::Equal:
+        return spv::ExecutionMode::SpacingEqual;
+    case TessSpacing::FractionalOdd:
+        return spv::ExecutionMode::SpacingFractionalOdd;
+    case TessSpacing::FractionalEven:
+        return spv::ExecutionMode::SpacingFractionalEven;
+    }
+    throw InvalidArgument("Tessellation spacing {}", spacing);
+}
+
 void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
     const std::span interfaces(ctx.interfaces.data(), ctx.interfaces.size());
     spv::ExecutionModel execution_model{};
@@ -134,6 +160,19 @@ void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
     case Stage::VertexB:
         execution_model = spv::ExecutionModel::Vertex;
         break;
+    case Stage::TessellationControl:
+        execution_model = spv::ExecutionModel::TessellationControl;
+        ctx.AddCapability(spv::Capability::Tessellation);
+        ctx.AddExecutionMode(main, spv::ExecutionMode::OutputVertices, program.invocations);
+        break;
+    case Stage::TessellationEval:
+        execution_model = spv::ExecutionModel::TessellationEvaluation;
+        ctx.AddCapability(spv::Capability::Tessellation);
+        ctx.AddExecutionMode(main, ExecutionMode(ctx.profile.tess_primitive));
+        ctx.AddExecutionMode(main, ExecutionMode(ctx.profile.tess_spacing));
+        ctx.AddExecutionMode(main, ctx.profile.tess_clockwise ? spv::ExecutionMode::VertexOrderCw
+                                                              : spv::ExecutionMode::VertexOrderCcw);
+        break;
     case Stage::Geometry:
         execution_model = spv::ExecutionModel::Geometry;
         ctx.AddCapability(spv::Capability::Geometry);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 55b2edba0c..8caf30f1b0 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -55,6 +55,8 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex);
 void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, Id vertex);
 Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex);
 void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, Id vertex);
+Id EmitGetPatch(EmitContext& ctx, IR::Patch patch);
+void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value);
 void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value);
 void EmitSetFragDepth(EmitContext& ctx, Id value);
 void EmitGetZFlag(EmitContext& ctx);
@@ -67,6 +69,7 @@ void EmitSetCFlag(EmitContext& ctx);
 void EmitSetOFlag(EmitContext& ctx);
 Id EmitWorkgroupId(EmitContext& ctx);
 Id EmitLocalInvocationId(EmitContext& ctx);
+Id EmitInvocationId(EmitContext& ctx);
 Id EmitIsHelperInvocation(EmitContext& ctx);
 Id EmitLoadLocal(EmitContext& ctx, Id word_offset);
 void EmitWriteLocal(EmitContext& ctx, Id word_offset, Id value);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 59c56c5ba8..4a1aeece5a 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -32,13 +32,26 @@ std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
 
 template <typename... Args>
 Id AttrPointer(EmitContext& ctx, Id pointer_type, Id vertex, Id base, Args&&... args) {
-    if (ctx.stage == Stage::Geometry) {
+    switch (ctx.stage) {
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+    case Stage::Geometry:
         return ctx.OpAccessChain(pointer_type, base, vertex, std::forward<Args>(args)...);
-    } else {
+    default:
         return ctx.OpAccessChain(pointer_type, base, std::forward<Args>(args)...);
     }
 }
 
+template <typename... Args>
+Id OutputAccessChain(EmitContext& ctx, Id result_type, Id base, Args&&... args) {
+    if (ctx.stage == Stage::TessellationControl) {
+        const Id invocation_id{ctx.OpLoad(ctx.U32[1], ctx.invocation_id)};
+        return ctx.OpAccessChain(result_type, base, invocation_id, std::forward<Args>(args)...);
+    } else {
+        return ctx.OpAccessChain(result_type, base, std::forward<Args>(args)...);
+    }
+}
+
 std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
     if (IR::IsGeneric(attr)) {
         const u32 index{IR::GenericAttributeIndex(attr)};
@@ -49,7 +62,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
         } else {
             const u32 index_element{element - info.first_element};
             const Id index_id{ctx.Constant(ctx.U32[1], index_element)};
-            return ctx.OpAccessChain(ctx.output_f32, info.id, index_id);
+            return OutputAccessChain(ctx, ctx.output_f32, info.id, index_id);
         }
     }
     switch (attr) {
@@ -61,7 +74,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
     case IR::Attribute::PositionW: {
         const u32 element{static_cast<u32>(attr) % 4};
         const Id element_id{ctx.Constant(ctx.U32[1], element)};
-        return ctx.OpAccessChain(ctx.output_f32, ctx.output_position, element_id);
+        return OutputAccessChain(ctx, ctx.output_f32, ctx.output_position, element_id);
     }
     case IR::Attribute::ClipDistance0:
     case IR::Attribute::ClipDistance1:
@@ -74,7 +87,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
         const u32 base{static_cast<u32>(IR::Attribute::ClipDistance0)};
         const u32 index{static_cast<u32>(attr) - base};
         const Id clip_num{ctx.Constant(ctx.U32[1], index)};
-        return ctx.OpAccessChain(ctx.output_f32, ctx.clip_distances, clip_num);
+        return OutputAccessChain(ctx, ctx.output_f32, ctx.clip_distances, clip_num);
     }
     case IR::Attribute::Layer:
         return ctx.profile.support_viewport_index_layer_non_geometry ||
@@ -222,11 +235,18 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex) {
                             ctx.Constant(ctx.U32[1], std::numeric_limits<u32>::max()),
                             ctx.u32_zero_value);
     case IR::Attribute::PointSpriteS:
-        return ctx.OpLoad(ctx.F32[1], AttrPointer(ctx, ctx.input_f32, vertex, ctx.point_coord,
-                                                  ctx.u32_zero_value));
+        return ctx.OpLoad(ctx.F32[1],
+                          ctx.OpAccessChain(ctx.input_f32, ctx.point_coord, ctx.u32_zero_value));
     case IR::Attribute::PointSpriteT:
-        return ctx.OpLoad(ctx.F32[1], AttrPointer(ctx, ctx.input_f32, vertex, ctx.point_coord,
-                                                  ctx.Constant(ctx.U32[1], 1U)));
+        return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(ctx.input_f32, ctx.point_coord,
+                                                        ctx.Constant(ctx.U32[1], 1U)));
+    case IR::Attribute::TessellationEvaluationPointU:
+        return ctx.OpLoad(ctx.F32[1],
+                          ctx.OpAccessChain(ctx.input_f32, ctx.tess_coord, ctx.u32_zero_value));
+    case IR::Attribute::TessellationEvaluationPointV:
+        return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(ctx.input_f32, ctx.tess_coord,
+                                                        ctx.Constant(ctx.U32[1], 1U)));
+
     default:
         throw NotImplementedException("Read attribute {}", attr);
     }
@@ -240,9 +260,12 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, [[maybe_un
 }
 
 Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex) {
-    if (ctx.stage == Stage::Geometry) {
+    switch (ctx.stage) {
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+    case Stage::Geometry:
         return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset, vertex);
-    } else {
+    default:
         return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset);
     }
 }
@@ -251,6 +274,45 @@ void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, [[maybe_unus
     ctx.OpFunctionCall(ctx.void_id, ctx.indexed_store_func, offset, value);
 }
 
+Id EmitGetPatch(EmitContext& ctx, IR::Patch patch) {
+    if (!IR::IsGeneric(patch)) {
+        throw NotImplementedException("Non-generic patch load");
+    }
+    const u32 index{IR::GenericPatchIndex(patch)};
+    const Id element{ctx.Constant(ctx.U32[1], IR::GenericPatchElement(patch))};
+    const Id pointer{ctx.OpAccessChain(ctx.input_f32, ctx.patches.at(index), element)};
+    return ctx.OpLoad(ctx.F32[1], pointer);
+}
+
+void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value) {
+    const Id pointer{[&] {
+        if (IR::IsGeneric(patch)) {
+            const u32 index{IR::GenericPatchIndex(patch)};
+            const Id element{ctx.Constant(ctx.U32[1], IR::GenericPatchElement(patch))};
+            return ctx.OpAccessChain(ctx.output_f32, ctx.patches.at(index), element);
+        }
+        switch (patch) {
+        case IR::Patch::TessellationLodLeft:
+        case IR::Patch::TessellationLodRight:
+        case IR::Patch::TessellationLodTop:
+        case IR::Patch::TessellationLodBottom: {
+            const u32 index{static_cast<u32>(patch) - u32(IR::Patch::TessellationLodLeft)};
+            const Id index_id{ctx.Constant(ctx.U32[1], index)};
+            return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_outer, index_id);
+        }
+        case IR::Patch::TessellationLodInteriorU:
+            return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_inner,
+                                     ctx.u32_zero_value);
+        case IR::Patch::TessellationLodInteriorV:
+            return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_inner,
+                                     ctx.Constant(ctx.U32[1], 1u));
+        default:
+            throw NotImplementedException("Patch {}", patch);
+        }
+    }()};
+    ctx.OpStore(pointer, value);
+}
+
 void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value) {
     const Id component_id{ctx.Constant(ctx.U32[1], component)};
     const Id pointer{ctx.OpAccessChain(ctx.output_f32, ctx.frag_color.at(index), component_id)};
@@ -301,6 +363,10 @@ Id EmitLocalInvocationId(EmitContext& ctx) {
     return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id);
 }
 
+Id EmitInvocationId(EmitContext& ctx) {
+    return ctx.OpLoad(ctx.U32[1], ctx.invocation_id);
+}
+
 Id EmitIsHelperInvocation(EmitContext& ctx) {
     return ctx.OpLoad(ctx.U1, ctx.is_helper_invocation);
 }
-- 
cgit v1.2.3-70-g09d2