From be94ee88d227d0d3dbeabe9ade98bacd910c7a7e Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 5 Feb 2021 19:19:36 -0300
Subject: shader: Make typed IR

---
 .../maxwell/translate/impl/floating_point_add.cpp  | 12 +++++------
 .../impl/floating_point_conversion_integer.cpp     | 20 +++++++++----------
 .../impl/floating_point_fused_multiply_add.cpp     | 16 +++++++--------
 .../impl/floating_point_multi_function.cpp         |  6 +++---
 .../translate/impl/floating_point_multiply.cpp     | 13 ++++++------
 .../frontend/maxwell/translate/impl/impl.cpp       | 20 +++++++++++++++++++
 .../frontend/maxwell/translate/impl/impl.h         |  6 ++++++
 .../translate/impl/load_store_attribute.cpp        | 23 +++++++++++-----------
 .../maxwell/translate/impl/load_store_memory.cpp   |  4 ++--
 9 files changed, 74 insertions(+), 46 deletions(-)

(limited to 'src/shader_recompiler/frontend/maxwell')

diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp
index d2c44b9ccd..cb3a326cfa 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_add.cpp
@@ -11,7 +11,7 @@ namespace Shader::Maxwell {
 namespace {
 
 void FADD(TranslatorVisitor& v, u64 insn, bool sat, bool cc, bool ftz, FpRounding fp_rounding,
-          const IR::U32& src_b, bool abs_a, bool neg_a, bool abs_b, bool neg_b) {
+          const IR::F32& src_b, bool abs_a, bool neg_a, bool abs_b, bool neg_b) {
     union {
         u64 raw;
         BitField<0, 8, IR::Reg> dest_reg;
@@ -24,17 +24,17 @@ void FADD(TranslatorVisitor& v, u64 insn, bool sat, bool cc, bool ftz, FpRoundin
     if (cc) {
         throw NotImplementedException("FADD CC");
     }
-    const IR::U32 op_a{v.ir.FPAbsNeg(v.X(fadd.src_a), abs_a, neg_a)};
-    const IR::U32 op_b{v.ir.FPAbsNeg(src_b, abs_b, neg_b)};
+    const IR::F32 op_a{v.ir.FPAbsNeg(v.F(fadd.src_a), abs_a, neg_a)};
+    const IR::F32 op_b{v.ir.FPAbsNeg(src_b, abs_b, neg_b)};
     IR::FpControl control{
         .no_contraction{true},
         .rounding{CastFpRounding(fp_rounding)},
         .fmz_mode{ftz ? IR::FmzMode::FTZ : IR::FmzMode::None},
     };
-    v.X(fadd.dest_reg, v.ir.FPAdd(op_a, op_b, control));
+    v.F(fadd.dest_reg, v.ir.FPAdd(op_a, op_b, control));
 }
 
-void FADD(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) {
+void FADD(TranslatorVisitor& v, u64 insn, const IR::F32& src_b) {
     union {
         u64 raw;
         BitField<39, 2, FpRounding> fp_rounding;
@@ -53,7 +53,7 @@ void FADD(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) {
 } // Anonymous namespace
 
 void TranslatorVisitor::FADD_reg(u64 insn) {
-    FADD(*this, insn, GetReg20(insn));
+    FADD(*this, insn, GetReg20F(insn));
 }
 
 void TranslatorVisitor::FADD_cbuf(u64) {
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp
index c4288d9a83..acd8445ad1 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_conversion_integer.cpp
@@ -55,21 +55,21 @@ size_t BitSize(DestFormat dest_format) {
     }
 }
 
-void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::U16U32U64& op_a) {
+void TranslateF2I(TranslatorVisitor& v, u64 insn, const IR::F16F32F64& src_a) {
     // F2I is used to convert from a floating point value to an integer
     const F2I f2i{insn};
 
-    const IR::U16U32U64 float_value{v.ir.FPAbsNeg(op_a, f2i.abs != 0, f2i.neg != 0)};
-    const IR::U16U32U64 rounded_value{[&] {
+    const IR::F16F32F64 op_a{v.ir.FPAbsNeg(src_a, f2i.abs != 0, f2i.neg != 0)};
+    const IR::F16F32F64 rounded_value{[&] {
         switch (f2i.rounding) {
         case Rounding::Round:
-            return v.ir.FPRoundEven(float_value);
+            return v.ir.FPRoundEven(op_a);
         case Rounding::Floor:
-            return v.ir.FPFloor(float_value);
+            return v.ir.FPFloor(op_a);
         case Rounding::Ceil:
-            return v.ir.FPCeil(float_value);
+            return v.ir.FPCeil(op_a);
         case Rounding::Trunc:
-            return v.ir.FPTrunc(float_value);
+            return v.ir.FPTrunc(op_a);
         default:
             throw NotImplementedException("Invalid F2I rounding {}", f2i.rounding.Value());
         }
@@ -105,12 +105,12 @@ void TranslatorVisitor::F2I_reg(u64 insn) {
         BitField<20, 8, IR::Reg> src_reg;
     } const f2i{insn};
 
-    const IR::U16U32U64 op_a{[&]() -> IR::U16U32U64 {
+    const IR::F16F32F64 op_a{[&]() -> IR::F16F32F64 {
         switch (f2i.base.src_format) {
         case SrcFormat::F16:
-            return ir.CompositeExtract(ir.UnpackFloat2x16(X(f2i.src_reg)), f2i.base.half);
+            return IR::F16{ir.CompositeExtract(ir.UnpackFloat2x16(X(f2i.src_reg)), f2i.base.half)};
         case SrcFormat::F32:
-            return X(f2i.src_reg);
+            return F(f2i.src_reg);
         case SrcFormat::F64:
             return ir.PackDouble2x32(ir.CompositeConstruct(X(f2i.src_reg), X(f2i.src_reg + 1)));
         default:
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp
index 30ca052ec5..1464f2807a 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_fused_multiply_add.cpp
@@ -9,7 +9,7 @@
 
 namespace Shader::Maxwell {
 namespace {
-void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& src_c, bool neg_a,
+void FFMA(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, const IR::F32& src_c, bool neg_a,
           bool neg_b, bool neg_c, bool sat, bool cc, FmzMode fmz_mode, FpRounding fp_rounding) {
     union {
         u64 raw;
@@ -23,18 +23,18 @@ void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& s
     if (cc) {
         throw NotImplementedException("FFMA CC");
     }
-    const IR::U32 op_a{v.ir.FPAbsNeg(v.X(ffma.src_a), false, neg_a)};
-    const IR::U32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)};
-    const IR::U32 op_c{v.ir.FPAbsNeg(src_c, false, neg_c)};
+    const IR::F32 op_a{v.ir.FPAbsNeg(v.F(ffma.src_a), false, neg_a)};
+    const IR::F32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)};
+    const IR::F32 op_c{v.ir.FPAbsNeg(src_c, false, neg_c)};
     const IR::FpControl fp_control{
         .no_contraction{true},
         .rounding{CastFpRounding(fp_rounding)},
         .fmz_mode{CastFmzMode(fmz_mode)},
     };
-    v.X(ffma.dest_reg, v.ir.FPFma(op_a, op_b, op_c, fp_control));
+    v.F(ffma.dest_reg, v.ir.FPFma(op_a, op_b, op_c, fp_control));
 }
 
-void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& src_c) {
+void FFMA(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, const IR::F32& src_c) {
     union {
         u64 raw;
         BitField<47, 1, u64> cc;
@@ -51,7 +51,7 @@ void FFMA(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, const IR::U32& s
 } // Anonymous namespace
 
 void TranslatorVisitor::FFMA_reg(u64 insn) {
-    FFMA(*this, insn, GetReg20(insn), GetReg39(insn));
+    FFMA(*this, insn, GetReg20F(insn), GetReg39F(insn));
 }
 
 void TranslatorVisitor::FFMA_rc(u64) {
@@ -59,7 +59,7 @@ void TranslatorVisitor::FFMA_rc(u64) {
 }
 
 void TranslatorVisitor::FFMA_cr(u64 insn) {
-    FFMA(*this, insn, GetCbuf(insn), GetReg39(insn));
+    FFMA(*this, insn, GetCbufF(insn), GetReg39F(insn));
 }
 
 void TranslatorVisitor::FFMA_imm(u64) {
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp
index e2ab0dab22..90cddb18b4 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multi_function.cpp
@@ -35,8 +35,8 @@ void TranslatorVisitor::MUFU(u64 insn) {
         BitField<50, 1, u64> sat;
     } const mufu{insn};
 
-    const IR::U32 op_a{ir.FPAbsNeg(X(mufu.src_reg), mufu.abs != 0, mufu.neg != 0)};
-    IR::U32 value{[&]() -> IR::U32 {
+    const IR::F32 op_a{ir.FPAbsNeg(F(mufu.src_reg), mufu.abs != 0, mufu.neg != 0)};
+    IR::F32 value{[&]() -> IR::F32 {
         switch (mufu.operation) {
         case Operation::Cos:
             return ir.FPCosNotReduced(op_a);
@@ -65,7 +65,7 @@ void TranslatorVisitor::MUFU(u64 insn) {
         value = ir.FPSaturate(value);
     }
 
-    X(mufu.dest_reg, value);
+    F(mufu.dest_reg, value);
 }
 
 } // namespace Shader::Maxwell
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp
index 743a1e2f0f..1b1d38be7a 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/floating_point_multiply.cpp
@@ -4,6 +4,7 @@
 
 #include "common/bit_field.h"
 #include "common/common_types.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/ir/modifiers.h"
 #include "shader_recompiler/frontend/maxwell/translate/impl/common_encoding.h"
 #include "shader_recompiler/frontend/maxwell/translate/impl/impl.h"
@@ -43,7 +44,7 @@ float ScaleFactor(Scale scale) {
     throw NotImplementedException("Invalid FMUL scale {}", scale);
 }
 
-void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, FmzMode fmz_mode,
+void FMUL(TranslatorVisitor& v, u64 insn, const IR::F32& src_b, FmzMode fmz_mode,
           FpRounding fp_rounding, Scale scale, bool sat, bool cc, bool neg_b) {
     union {
         u64 raw;
@@ -57,23 +58,23 @@ void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b, FmzMode fmz_mode
     if (sat) {
         throw NotImplementedException("FMUL SAT");
     }
-    IR::U32 op_a{v.X(fmul.src_a)};
+    IR::F32 op_a{v.F(fmul.src_a)};
     if (scale != Scale::None) {
         if (fmz_mode != FmzMode::FTZ || fp_rounding != FpRounding::RN) {
             throw NotImplementedException("FMUL scale with non-FMZ or non-RN modifiers");
         }
         op_a = v.ir.FPMul(op_a, v.ir.Imm32(ScaleFactor(scale)));
     }
-    const IR::U32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)};
+    const IR::F32 op_b{v.ir.FPAbsNeg(src_b, false, neg_b)};
     const IR::FpControl fp_control{
         .no_contraction{true},
         .rounding{CastFpRounding(fp_rounding)},
         .fmz_mode{CastFmzMode(fmz_mode)},
     };
-    v.X(fmul.dest_reg, v.ir.FPMul(op_a, op_b, fp_control));
+    v.F(fmul.dest_reg, v.ir.FPMul(op_a, op_b, fp_control));
 }
 
-void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) {
+void FMUL(TranslatorVisitor& v, u64 insn, const IR::F32& src_b) {
     union {
         u64 raw;
         BitField<39, 2, FpRounding> fp_rounding;
@@ -90,7 +91,7 @@ void FMUL(TranslatorVisitor& v, u64 insn, const IR::U32& src_b) {
 } // Anonymous namespace
 
 void TranslatorVisitor::FMUL_reg(u64 insn) {
-    return FMUL(*this, insn, GetReg20(insn));
+    return FMUL(*this, insn, GetReg20F(insn));
 }
 
 void TranslatorVisitor::FMUL_cbuf(u64) {
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp
index 548c7f611d..3c9eaddd94 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.cpp
@@ -12,10 +12,18 @@ IR::U32 TranslatorVisitor::X(IR::Reg reg) {
     return ir.GetReg(reg);
 }
 
+IR::F32 TranslatorVisitor::F(IR::Reg reg) {
+    return ir.BitCast<IR::F32>(X(reg));
+}
+
 void TranslatorVisitor::X(IR::Reg dest_reg, const IR::U32& value) {
     ir.SetReg(dest_reg, value);
 }
 
+void TranslatorVisitor::F(IR::Reg dest_reg, const IR::F32& value) {
+    X(dest_reg, ir.BitCast<IR::U32>(value));
+}
+
 IR::U32 TranslatorVisitor::GetReg20(u64 insn) {
     union {
         u64 raw;
@@ -32,6 +40,14 @@ IR::U32 TranslatorVisitor::GetReg39(u64 insn) {
     return X(reg.index);
 }
 
+IR::F32 TranslatorVisitor::GetReg20F(u64 insn) {
+    return ir.BitCast<IR::F32>(GetReg20(insn));
+}
+
+IR::F32 TranslatorVisitor::GetReg39F(u64 insn) {
+    return ir.BitCast<IR::F32>(GetReg39(insn));
+}
+
 IR::U32 TranslatorVisitor::GetCbuf(u64 insn) {
     union {
         u64 raw;
@@ -49,6 +65,10 @@ IR::U32 TranslatorVisitor::GetCbuf(u64 insn) {
     return ir.GetCbuf(binding, byte_offset);
 }
 
+IR::F32 TranslatorVisitor::GetCbufF(u64 insn) {
+    return ir.BitCast<IR::F32>(GetCbuf(insn));
+}
+
 IR::U32 TranslatorVisitor::GetImm20(u64 insn) {
     union {
         u64 raw;
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h
index ef6d977fef..b701605d73 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/impl.h
@@ -296,12 +296,18 @@ public:
     void XMAD_imm(u64 insn);
 
     [[nodiscard]] IR::U32 X(IR::Reg reg);
+    [[nodiscard]] IR::F32 F(IR::Reg reg);
+
     void X(IR::Reg dest_reg, const IR::U32& value);
+    void F(IR::Reg dest_reg, const IR::F32& value);
 
     [[nodiscard]] IR::U32 GetReg20(u64 insn);
     [[nodiscard]] IR::U32 GetReg39(u64 insn);
+    [[nodiscard]] IR::F32 GetReg20F(u64 insn);
+    [[nodiscard]] IR::F32 GetReg39F(u64 insn);
 
     [[nodiscard]] IR::U32 GetCbuf(u64 insn);
+    [[nodiscard]] IR::F32 GetCbufF(u64 insn);
 
     [[nodiscard]] IR::U32 GetImm20(u64 insn);
 
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
index 23512db1a4..de65173e8d 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
@@ -5,22 +5,23 @@
 #include "common/bit_field.h"
 #include "common/common_types.h"
 #include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/maxwell/opcode.h"
 #include "shader_recompiler/frontend/maxwell/translate/impl/impl.h"
 
 namespace Shader::Maxwell {
 namespace {
 enum class InterpolationMode : u64 {
-    Pass = 0,
-    Multiply = 1,
-    Constant = 2,
-    Sc = 3,
+    Pass,
+    Multiply,
+    Constant,
+    Sc,
 };
 
 enum class SampleMode : u64 {
-    Default = 0,
-    Centroid = 1,
-    Offset = 2,
+    Default,
+    Centroid,
+    Offset,
 };
 } // Anonymous namespace
 
@@ -54,12 +55,12 @@ void TranslatorVisitor::IPA(u64 insn) {
     }
 
     const IR::Attribute attribute{ipa.attribute};
-    IR::U32 value{ir.GetAttribute(attribute)};
+    IR::F32 value{ir.GetAttribute(attribute)};
     if (IR::IsGeneric(attribute)) {
         // const bool is_perspective{UnimplementedReadHeader(GenericAttributeIndex(attribute))};
         const bool is_perspective{false};
         if (is_perspective) {
-            const IR::U32 rcp_position_w{ir.FPRecip(ir.GetAttribute(IR::Attribute::PositionW))};
+            const IR::F32 rcp_position_w{ir.FPRecip(ir.GetAttribute(IR::Attribute::PositionW))};
             value = ir.FPMul(value, rcp_position_w);
         }
     }
@@ -68,7 +69,7 @@ void TranslatorVisitor::IPA(u64 insn) {
     case InterpolationMode::Pass:
         break;
     case InterpolationMode::Multiply:
-        value = ir.FPMul(value, ir.GetReg(ipa.multiplier));
+        value = ir.FPMul(value, F(ipa.multiplier));
         break;
     case InterpolationMode::Constant:
         throw NotImplementedException("IPA.CONSTANT");
@@ -86,7 +87,7 @@ void TranslatorVisitor::IPA(u64 insn) {
         value = ir.FPSaturate(value);
     }
 
-    ir.SetReg(ipa.dest_reg, value);
+    F(ipa.dest_reg, value);
 }
 
 } // namespace Shader::Maxwell
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp
index c9669c6178..9f1570479d 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_memory.cpp
@@ -114,7 +114,7 @@ void TranslatorVisitor::LDG(u64 insn) {
         }
         const IR::Value vector{ir.LoadGlobal64(address)};
         for (int i = 0; i < 2; ++i) {
-            X(dest_reg + i, ir.CompositeExtract(vector, i));
+            X(dest_reg + i, IR::U32{ir.CompositeExtract(vector, i)});
         }
         break;
     }
@@ -124,7 +124,7 @@ void TranslatorVisitor::LDG(u64 insn) {
         }
         const IR::Value vector{ir.LoadGlobal128(address)};
         for (int i = 0; i < 4; ++i) {
-            X(dest_reg + i, ir.CompositeExtract(vector, i));
+            X(dest_reg + i, IR::U32{ir.CompositeExtract(vector, i)});
         }
         break;
     }
-- 
cgit v1.2.3-70-g09d2