From c4fd6b55bc9acd06b2fc89f84fd175d78e14110a Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Mon, 10 May 2021 18:21:28 -0300
Subject: glasm: Implement shuffle and vote instructions on GLASM

---
 .../backend/glasm/emit_context.cpp                 |  17 +++
 src/shader_recompiler/backend/glasm/emit_context.h |   2 +
 src/shader_recompiler/backend/glasm/emit_glasm.cpp |   6 ++
 .../backend/glasm/emit_glasm_instructions.h        |  28 ++---
 .../backend/glasm/emit_glasm_not_implemented.cpp   |  84 +--------------
 .../backend/glasm/emit_glasm_warp.cpp              | 118 +++++++++++++++++++++
 .../backend/spirv/emit_context.cpp                 |   2 +-
 src/shader_recompiler/backend/spirv/emit_spirv.cpp |   4 +-
 .../ir_opt/collect_shader_info_pass.cpp            |   4 +-
 src/shader_recompiler/shader_info.h                |   1 +
 10 files changed, 166 insertions(+), 100 deletions(-)

(limited to 'src/shader_recompiler')

diff --git a/src/shader_recompiler/backend/glasm/emit_context.cpp b/src/shader_recompiler/backend/glasm/emit_context.cpp
index 9f839f3bf7..f9d83dd91b 100644
--- a/src/shader_recompiler/backend/glasm/emit_context.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_context.cpp
@@ -25,6 +25,23 @@ EmitContext::EmitContext(IR::Program& program) {
     if (const size_t num = program.info.storage_buffers_descriptors.size(); num > 0) {
         Add("PARAM c[{}]={{program.local[0..{}]}};", num, num - 1);
     }
+    switch (program.stage) {
+    case Stage::VertexA:
+    case Stage::VertexB:
+        stage_name = "vertex";
+        break;
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+    case Stage::Geometry:
+        stage_name = "primitive";
+        break;
+    case Stage::Fragment:
+        stage_name = "fragment";
+        break;
+    case Stage::Compute:
+        stage_name = "compute";
+        break;
+    }
 }
 
 } // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/glasm/emit_context.h b/src/shader_recompiler/backend/glasm/emit_context.h
index 37663c1c8f..4efe42adad 100644
--- a/src/shader_recompiler/backend/glasm/emit_context.h
+++ b/src/shader_recompiler/backend/glasm/emit_context.h
@@ -45,6 +45,8 @@ public:
 
     std::string code;
     RegAlloc reg_alloc{*this};
+
+    std::string_view stage_name = "invalid";
 };
 
 } // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm.cpp b/src/shader_recompiler/backend/glasm/emit_glasm.cpp
index ad27b8b067..8b42cbf795 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm.cpp
@@ -189,6 +189,12 @@ void SetupOptions(std::string& header, Info info) {
     if (info.uses_atomic_f16x2_add || info.uses_atomic_f16x2_min || info.uses_atomic_f16x2_max) {
         header += "OPTION NV_shader_atomic_fp16_vector;";
     }
+    if (info.uses_subgroup_invocation_id || info.uses_subgroup_mask || info.uses_subgroup_vote) {
+        header += "OPTION NV_shader_thread_group;"; // also required by TGALL/TGANY/TGEQ/TGBALLOT
+    }
+    if (info.uses_subgroup_shuffles) {
+        header += "OPTION NV_shader_thread_shuffle;";
+    }
 }
 } // Anonymous namespace
 
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
index 1bbd02022a..75613571fb 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
@@ -584,24 +584,24 @@ void EmitImageAtomicXor32(EmitContext& ctx, IR::Inst& inst, const IR::Value& ind
                           ScalarU32 value);
 void EmitImageAtomicExchange32(EmitContext& ctx, IR::Inst& inst, const IR::Value& index,
                                Register coords, ScalarU32 value);
-void EmitLaneId(EmitContext& ctx);
-void EmitVoteAll(EmitContext& ctx, ScalarS32 pred);
-void EmitVoteAny(EmitContext& ctx, ScalarS32 pred);
-void EmitVoteEqual(EmitContext& ctx, ScalarS32 pred);
-void EmitSubgroupBallot(EmitContext& ctx, ScalarS32 pred);
-void EmitSubgroupEqMask(EmitContext& ctx);
-void EmitSubgroupLtMask(EmitContext& ctx);
-void EmitSubgroupLeMask(EmitContext& ctx);
-void EmitSubgroupGtMask(EmitContext& ctx);
-void EmitSubgroupGeMask(EmitContext& ctx);
+void EmitLaneId(EmitContext& ctx, IR::Inst& inst);
+void EmitVoteAll(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred);
+void EmitVoteAny(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred);
+void EmitVoteEqual(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred);
+void EmitSubgroupBallot(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred);
+void EmitSubgroupEqMask(EmitContext& ctx, IR::Inst& inst);
+void EmitSubgroupLtMask(EmitContext& ctx, IR::Inst& inst);
+void EmitSubgroupLeMask(EmitContext& ctx, IR::Inst& inst);
+void EmitSubgroupGtMask(EmitContext& ctx, IR::Inst& inst);
+void EmitSubgroupGeMask(EmitContext& ctx, IR::Inst& inst);
 void EmitShuffleIndex(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                      ScalarU32 clamp, ScalarU32 segmentation_mask);
+                      const IR::Value& clamp, const IR::Value& segmentation_mask);
 void EmitShuffleUp(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                   ScalarU32 clamp, ScalarU32 segmentation_mask);
+                   const IR::Value& clamp, const IR::Value& segmentation_mask);
 void EmitShuffleDown(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                     ScalarU32 clamp, ScalarU32 segmentation_mask);
+                     const IR::Value& clamp, const IR::Value& segmentation_mask);
 void EmitShuffleButterfly(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                          ScalarU32 clamp, ScalarU32 segmentation_mask);
+                          const IR::Value& clamp, const IR::Value& segmentation_mask);
 void EmitFSwizzleAdd(EmitContext& ctx, ScalarF32 op_a, ScalarF32 op_b, ScalarU32 swizzle);
 void EmitDPdxFine(EmitContext& ctx, ScalarF32 op_a);
 void EmitDPdyFine(EmitContext& ctx, ScalarF32 op_a);
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
index 85110bcc99..3c0a74e3cd 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
@@ -21,9 +21,7 @@ void EmitPhi(EmitContext& ctx, IR::Inst& inst) {
     NotImplemented();
 }
 
-void EmitVoid(EmitContext& ctx) {
-    NotImplemented();
-}
+void EmitVoid(EmitContext&) {} // emits no GLASM code
 
 void EmitBranch(EmitContext& ctx) {
     NotImplemented();
@@ -636,84 +634,4 @@ void EmitImageAtomicExchange32(EmitContext& ctx, IR::Inst& inst, const IR::Value
     NotImplemented();
 }
 
-void EmitLaneId(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitVoteAll(EmitContext& ctx, ScalarS32 pred) {
-    NotImplemented();
-}
-
-void EmitVoteAny(EmitContext& ctx, ScalarS32 pred) {
-    NotImplemented();
-}
-
-void EmitVoteEqual(EmitContext& ctx, ScalarS32 pred) {
-    NotImplemented();
-}
-
-void EmitSubgroupBallot(EmitContext& ctx, ScalarS32 pred) {
-    NotImplemented();
-}
-
-void EmitSubgroupEqMask(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitSubgroupLtMask(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitSubgroupLeMask(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitSubgroupGtMask(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitSubgroupGeMask(EmitContext& ctx) {
-    NotImplemented();
-}
-
-void EmitShuffleIndex(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                      ScalarU32 clamp, ScalarU32 segmentation_mask) {
-    NotImplemented();
-}
-
-void EmitShuffleUp(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                   ScalarU32 clamp, ScalarU32 segmentation_mask) {
-    NotImplemented();
-}
-
-void EmitShuffleDown(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                     ScalarU32 clamp, ScalarU32 segmentation_mask) {
-    NotImplemented();
-}
-
-void EmitShuffleButterfly(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
-                          ScalarU32 clamp, ScalarU32 segmentation_mask) {
-    NotImplemented();
-}
-
-void EmitFSwizzleAdd(EmitContext& ctx, ScalarF32 op_a, ScalarF32 op_b, ScalarU32 swizzle) {
-    NotImplemented();
-}
-
-void EmitDPdxFine(EmitContext& ctx, ScalarF32 op_a) {
-    NotImplemented();
-}
-
-void EmitDPdyFine(EmitContext& ctx, ScalarF32 op_a) {
-    NotImplemented();
-}
-
-void EmitDPdxCoarse(EmitContext& ctx, ScalarF32 op_a) {
-    NotImplemented();
-}
-
-void EmitDPdyCoarse(EmitContext& ctx, ScalarF32 op_a) {
-    NotImplemented();
-}
-
 } // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_warp.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_warp.cpp
index e69de29bb2..37eb577cde 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_warp.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_warp.cpp
@@ -0,0 +1,118 @@
+// Copyright 2021 yuzu Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#include "shader_recompiler/backend/glasm/emit_context.h"
+#include "shader_recompiler/backend/glasm/emit_glasm_instructions.h"
+#include "shader_recompiler/frontend/ir/value.h"
+
+namespace Shader::Backend::GLASM {
+
+void EmitLaneId(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.S {}.x,{}.threadid;", inst, ctx.stage_name); // inst.x = <stage_name>.threadid
+}
+
+void EmitVoteAll(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred) {
+    ctx.Add("TGALL.S {}.x,{};", inst, pred); // inst.x = TGALL.S(pred)
+}
+
+void EmitVoteAny(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred) {
+    ctx.Add("TGANY.S {}.x,{};", inst, pred); // inst.x = TGANY.S(pred)
+}
+
+void EmitVoteEqual(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred) {
+    ctx.Add("TGEQ.S {}.x,{};", inst, pred); // inst.x = TGEQ.S(pred)
+}
+
+void EmitSubgroupBallot(EmitContext& ctx, IR::Inst& inst, ScalarS32 pred) {
+    ctx.Add("TGBALLOT {}.x,{};", inst, pred); // inst.x = TGBALLOT(pred); no type suffix is emitted
+}
+
+void EmitSubgroupEqMask(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.U {},{}.threadeqmask;", inst, ctx.stage_name); // inst = <stage_name>.threadeqmask
+}
+
+void EmitSubgroupLtMask(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.U {},{}.threadltmask;", inst, ctx.stage_name); // inst = <stage_name>.threadltmask
+}
+
+void EmitSubgroupLeMask(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.U {},{}.threadlemask;", inst, ctx.stage_name); // inst = <stage_name>.threadlemask
+}
+
+void EmitSubgroupGtMask(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.U {},{}.threadgtmask;", inst, ctx.stage_name); // inst = <stage_name>.threadgtmask
+}
+
+void EmitSubgroupGeMask(EmitContext& ctx, IR::Inst& inst) {
+    ctx.Add("MOV.U {},{}.threadgemask;", inst, ctx.stage_name); // inst = <stage_name>.threadgemask
+}
+
+// Emits "SHF<op>.U dst,value,index,mask;" followed by a MOV of dst.y into the .x component of
+// inst's register. dst is the in-bounds pseudo-op's register when one exists, else inst's own.
+static void Shuffle(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
+                    const IR::Value& clamp, const IR::Value& segmentation_mask,
+                    std::string_view op) {
+    const bool is_constant_mask{clamp.IsImmediate() && segmentation_mask.IsImmediate()};
+    std::string packed_mask{"RC"};
+    if (is_constant_mask) {
+        // Both operands known at emit time: fold clamp | (segmentation_mask << 8) into a literal.
+        packed_mask = fmt::to_string(clamp.U32() | (segmentation_mask.U32() << 8));
+    } else {
+        // Otherwise pack them at runtime into RC.x with a bitfield insert.
+        ctx.Add("BFI.U RC.x,{{5,8,0,0}},{},{};",
+                ScalarU32{ctx.reg_alloc.Consume(segmentation_mask)},
+                ScalarU32{ctx.reg_alloc.Consume(clamp)});
+    }
+    const Register result{ctx.reg_alloc.Define(inst)};
+    IR::Inst* const bounds_op{inst.GetAssociatedPseudoOperation(IR::Opcode::GetInBoundsFromOp)};
+    const Register shf_dst{bounds_op ? ctx.reg_alloc.Define(*bounds_op) : result};
+    ctx.Add("SHF{}.U {},{},{},{};"
+            "MOV.U {}.x,{}.y;",
+            op, shf_dst, value, index, packed_mask, result, shf_dst);
+    if (bounds_op) {
+        bounds_op->Invalidate();
+    }
+}
+
+void EmitShuffleIndex(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
+                      const IR::Value& clamp, const IR::Value& segmentation_mask) {
+    Shuffle(ctx, inst, value, index, clamp, segmentation_mask, "IDX"); // emits SHFIDX.U
+}
+
+void EmitShuffleUp(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
+                   const IR::Value& clamp, const IR::Value& segmentation_mask) {
+    Shuffle(ctx, inst, value, index, clamp, segmentation_mask, "UP"); // emits SHFUP.U
+}
+
+void EmitShuffleDown(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
+                     const IR::Value& clamp, const IR::Value& segmentation_mask) {
+    Shuffle(ctx, inst, value, index, clamp, segmentation_mask, "DOWN"); // emits SHFDOWN.U
+}
+
+void EmitShuffleButterfly(EmitContext& ctx, IR::Inst& inst, ScalarU32 value, ScalarU32 index,
+                          const IR::Value& clamp, const IR::Value& segmentation_mask) {
+    Shuffle(ctx, inst, value, index, clamp, segmentation_mask, "XOR"); // emits SHFXOR.U
+}
+
+void EmitFSwizzleAdd(EmitContext&, ScalarF32, ScalarF32, ScalarU32) {
+    throw NotImplementedException("GLASM instruction"); // unimplemented stub: always throws
+}
+
+void EmitDPdxFine(EmitContext&, ScalarF32) {
+    throw NotImplementedException("GLASM instruction"); // unimplemented stub: always throws
+}
+
+void EmitDPdyFine(EmitContext&, ScalarF32) {
+    throw NotImplementedException("GLASM instruction"); // unimplemented stub: always throws
+}
+
+void EmitDPdxCoarse(EmitContext&, ScalarF32) {
+    throw NotImplementedException("GLASM instruction"); // unimplemented stub: always throws
+}
+
+void EmitDPdyCoarse(EmitContext&, ScalarF32) {
+    throw NotImplementedException("GLASM instruction"); // unimplemented stub: always throws
+}
+
+} // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index be88b76f7e..9759591bdd 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -1168,7 +1168,7 @@ void EmitContext::DefineInputs(const Info& info) {
         subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
         subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
     }
-    if (info.uses_subgroup_invocation_id ||
+    if (info.uses_subgroup_invocation_id || info.uses_subgroup_shuffles ||
         (profile.warp_size_potentially_larger_than_guest &&
          (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
         subgroup_local_invocation_id =
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 0681dfd168..2dad87e872 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -318,7 +318,9 @@ void SetupCapabilities(const Profile& profile, const Info& info, EmitContext& ct
         ctx.AddExtension("SPV_KHR_shader_draw_parameters");
         ctx.AddCapability(spv::Capability::DrawParameters);
     }
-    if ((info.uses_subgroup_vote || info.uses_subgroup_invocation_id) && profile.support_vote) {
+    if ((info.uses_subgroup_vote || info.uses_subgroup_invocation_id ||
+         info.uses_subgroup_shuffles) &&
+        profile.support_vote) {
         ctx.AddExtension("SPV_KHR_shader_ballot");
         ctx.AddCapability(spv::Capability::SubgroupBallotKHR);
         if (!profile.warp_size_potentially_larger_than_guest) {
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index 13b793d572..ea08aacc3e 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -504,11 +504,13 @@ void VisitUsages(Info& info, IR::Inst& inst) {
         info.uses_is_helper_invocation = true;
         break;
     case IR::Opcode::LaneId:
+        info.uses_subgroup_invocation_id = true;
+        break;
     case IR::Opcode::ShuffleIndex:
     case IR::Opcode::ShuffleUp:
     case IR::Opcode::ShuffleDown:
     case IR::Opcode::ShuffleButterfly:
-        info.uses_subgroup_invocation_id = true;
+        info.uses_subgroup_shuffles = true;
         break;
     case IR::Opcode::GetCbufU8:
     case IR::Opcode::GetCbufS8:
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index a50a9a18c5..d6c32fbe52 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -116,6 +116,7 @@ struct Info {
     bool uses_sample_id{};
     bool uses_is_helper_invocation{};
     bool uses_subgroup_invocation_id{};
+    bool uses_subgroup_shuffles{};
     std::array<bool, 30> uses_patches{};
 
     std::array<InputVarying, 32> input_generics{};
-- 
cgit v1.2.3-70-g09d2