From be94ee88d227d0d3dbeabe9ade98bacd910c7a7e Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 5 Feb 2021 19:19:36 -0300
Subject: shader: Make typed IR

---
 src/shader_recompiler/frontend/ir/ir_emitter.cpp | 275 +++++++++++++++--------
 1 file changed, 184 insertions(+), 91 deletions(-)

(limited to 'src/shader_recompiler/frontend/ir/ir_emitter.cpp')

diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.cpp b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
index 1c5ae0109b..9d7dc034c9 100644
--- a/src/shader_recompiler/frontend/ir/ir_emitter.cpp
+++ b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
@@ -32,16 +32,16 @@ U32 IREmitter::Imm32(s32 value) const {
     return U32{Value{static_cast<u32>(value)}};
 }
 
-U32 IREmitter::Imm32(f32 value) const {
-    return U32{Value{Common::BitCast<u32>(value)}};
+F32 IREmitter::Imm32(f32 value) const {
+    return F32{Value{value}};
 }
 
 U64 IREmitter::Imm64(u64 value) const {
     return U64{Value{value}};
 }
 
-U64 IREmitter::Imm64(f64 value) const {
-    return U64{Value{Common::BitCast<u64>(value)}};
+F64 IREmitter::Imm64(f64 value) const {
+    return F64{Value{value}};
 }
 
 void IREmitter::Branch(IR::Block* label) {
@@ -121,11 +121,11 @@ void IREmitter::SetOFlag(const U1& value) {
     Inst(Opcode::SetOFlag, value);
 }
 
-U32 IREmitter::GetAttribute(IR::Attribute attribute) {
-    return Inst<U32>(Opcode::GetAttribute, attribute);
+F32 IREmitter::GetAttribute(IR::Attribute attribute) {
+    return Inst<F32>(Opcode::GetAttribute, attribute);
 }
 
-void IREmitter::SetAttribute(IR::Attribute attribute, const U32& value) {
+void IREmitter::SetAttribute(IR::Attribute attribute, const F32& value) {
     Inst(Opcode::SetAttribute, attribute, value);
 }
 
@@ -225,50 +225,113 @@ U1 IREmitter::GetOverflowFromOp(const Value& op) {
     return Inst<U1>(Opcode::GetOverflowFromOp, op);
 }
 
-U16U32U64 IREmitter::FPAdd(const U16U32U64& a, const U16U32U64& b, FpControl control) {
+F16F32F64 IREmitter::FPAdd(const F16F32F64& a, const F16F32F64& b, FpControl control) {
     if (a.Type() != a.Type()) {
         throw InvalidArgument("Mismatching types {} and {}", a.Type(), b.Type());
     }
     switch (a.Type()) {
-    case Type::U16:
-        return Inst<U16>(Opcode::FPAdd16, Flags{control}, a, b);
-    case Type::U32:
-        return Inst<U32>(Opcode::FPAdd32, Flags{control}, a, b);
-    case Type::U64:
-        return Inst<U64>(Opcode::FPAdd64, Flags{control}, a, b);
+    case Type::F16:
+        return Inst<F16>(Opcode::FPAdd16, Flags{control}, a, b);
+    case Type::F32:
+        return Inst<F32>(Opcode::FPAdd32, Flags{control}, a, b);
+    case Type::F64:
+        return Inst<F64>(Opcode::FPAdd64, Flags{control}, a, b);
     default:
         ThrowInvalidType(a.Type());
     }
 }
 
-Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2) {
+Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2) {
     if (e1.Type() != e2.Type()) {
         throw InvalidArgument("Mismatching types {} and {}", e1.Type(), e2.Type());
     }
-    return Inst(Opcode::CompositeConstruct2, e1, e2);
+    switch (e1.Type()) {
+    case Type::U32:
+        return Inst(Opcode::CompositeConstructU32x2, e1, e2);
+    case Type::F16:
+        return Inst(Opcode::CompositeConstructF16x2, e1, e2);
+    case Type::F32:
+        return Inst(Opcode::CompositeConstructF32x2, e1, e2);
+    case Type::F64:
+        return Inst(Opcode::CompositeConstructF64x2, e1, e2);
+    default:
+        ThrowInvalidType(e1.Type());
+    }
 }
 
-Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3) {
+Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2, const Value& e3) {
     if (e1.Type() != e2.Type() || e1.Type() != e3.Type()) {
         throw InvalidArgument("Mismatching types {}, {}, and {}", e1.Type(), e2.Type(), e3.Type());
     }
-    return Inst(Opcode::CompositeConstruct3, e1, e2, e3);
+    switch (e1.Type()) {
+    case Type::U32:
+        return Inst(Opcode::CompositeConstructU32x3, e1, e2, e3);
+    case Type::F16:
+        return Inst(Opcode::CompositeConstructF16x3, e1, e2, e3);
+    case Type::F32:
+        return Inst(Opcode::CompositeConstructF32x3, e1, e2, e3);
+    case Type::F64:
+        return Inst(Opcode::CompositeConstructF64x3, e1, e2, e3);
+    default:
+        ThrowInvalidType(e1.Type());
+    }
 }
 
-Value IREmitter::CompositeConstruct(const UAny& e1, const UAny& e2, const UAny& e3,
-                                    const UAny& e4) {
+Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2, const Value& e3,
+                                    const Value& e4) {
     if (e1.Type() != e2.Type() || e1.Type() != e3.Type() || e1.Type() != e4.Type()) {
         throw InvalidArgument("Mismatching types {}, {}, {}, and {}", e1.Type(), e2.Type(),
                               e3.Type(), e4.Type());
     }
-    return Inst(Opcode::CompositeConstruct4, e1, e2, e3, e4);
+    switch (e1.Type()) {
+    case Type::U32:
+        return Inst(Opcode::CompositeConstructU32x4, e1, e2, e3, e4);
+    case Type::F16:
+        return Inst(Opcode::CompositeConstructF16x4, e1, e2, e3, e4);
+    case Type::F32:
+        return Inst(Opcode::CompositeConstructF32x4, e1, e2, e3, e4);
+    case Type::F64:
+        return Inst(Opcode::CompositeConstructF64x4, e1, e2, e3, e4);
+    default:
+        ThrowInvalidType(e1.Type());
+    }
 }
 
-UAny IREmitter::CompositeExtract(const Value& vector, size_t element) {
-    if (element >= 4) {
-        throw InvalidArgument("Out of bounds element {}", element);
+Value IREmitter::CompositeExtract(const Value& vector, size_t element) {
+    const auto read = [&](Opcode opcode, size_t limit) -> Value {
+        if (element >= limit) {
+            throw InvalidArgument("Out of bounds element {}", element);
+        }
+        return Inst(opcode, vector, Value{static_cast<u32>(element)});
+    };
+    switch (vector.Type()) {
+    case Type::U32x2:
+        return read(Opcode::CompositeExtractU32x2, 2);
+    case Type::U32x3:
+        return read(Opcode::CompositeExtractU32x3, 3);
+    case Type::U32x4:
+        return read(Opcode::CompositeExtractU32x4, 4);
+    case Type::F16x2:
+        return read(Opcode::CompositeExtractF16x2, 2);
+    case Type::F16x3:
+        return read(Opcode::CompositeExtractF16x3, 3);
+    case Type::F16x4:
+        return read(Opcode::CompositeExtractF16x4, 4);
+    case Type::F32x2:
+        return read(Opcode::CompositeExtractF32x2, 2);
+    case Type::F32x3:
+        return read(Opcode::CompositeExtractF32x3, 3);
+    case Type::F32x4:
+        return read(Opcode::CompositeExtractF32x4, 4);
+    case Type::F64x2:
+        return read(Opcode::CompositeExtractF64x2, 2);
+    case Type::F64x3:
+        return read(Opcode::CompositeExtractF64x3, 3);
+    case Type::F64x4:
+        return read(Opcode::CompositeExtractF64x4, 4);
+    default:
+        ThrowInvalidType(vector.Type());
     }
-    return Inst<UAny>(Opcode::CompositeExtract, vector, Imm32(static_cast<u32>(element)));
 }
 
 UAny IREmitter::Select(const U1& condition, const UAny& true_value, const UAny& false_value) {
@@ -289,6 +352,36 @@ UAny IREmitter::Select(const U1& condition, const UAny& true_value, const UAny&
     }
 }
 
+template <>
+IR::U32 IREmitter::BitCast<IR::U32, IR::F32>(const IR::F32& value) {
+    return Inst<IR::U32>(Opcode::BitCastU32F32, value);
+}
+
+template <>
+IR::F32 IREmitter::BitCast<IR::F32, IR::U32>(const IR::U32& value) {
+    return Inst<IR::F32>(Opcode::BitCastF32U32, value);
+}
+
+template <>
+IR::U16 IREmitter::BitCast<IR::U16, IR::F16>(const IR::F16& value) {
+    return Inst<IR::U16>(Opcode::BitCastU16F16, value);
+}
+
+template <>
+IR::F16 IREmitter::BitCast<IR::F16, IR::U16>(const IR::U16& value) {
+    return Inst<IR::F16>(Opcode::BitCastF16U16, value);
+}
+
+template <>
+IR::U64 IREmitter::BitCast<IR::U64, IR::F64>(const IR::F64& value) {
+    return Inst<IR::U64>(Opcode::BitCastU64F64, value);
+}
+
+template <>
+IR::F64 IREmitter::BitCast<IR::F64, IR::U64>(const IR::U64& value) {
+    return Inst<IR::F64>(Opcode::BitCastF64U64, value);
+}
+
 U64 IREmitter::PackUint2x32(const Value& vector) {
     return Inst<U64>(Opcode::PackUint2x32, vector);
 }
@@ -305,75 +398,75 @@ Value IREmitter::UnpackFloat2x16(const U32& value) {
     return Inst<Value>(Opcode::UnpackFloat2x16, value);
 }
 
-U64 IREmitter::PackDouble2x32(const Value& vector) {
-    return Inst<U64>(Opcode::PackDouble2x32, vector);
+F64 IREmitter::PackDouble2x32(const Value& vector) {
+    return Inst<F64>(Opcode::PackDouble2x32, vector);
 }
 
-Value IREmitter::UnpackDouble2x32(const U64& value) {
+Value IREmitter::UnpackDouble2x32(const F64& value) {
     return Inst<Value>(Opcode::UnpackDouble2x32, value);
 }
 
-U16U32U64 IREmitter::FPMul(const U16U32U64& a, const U16U32U64& b, FpControl control) {
+F16F32F64 IREmitter::FPMul(const F16F32F64& a, const F16F32F64& b, FpControl control) {
     if (a.Type() != b.Type()) {
         throw InvalidArgument("Mismatching types {} and {}", a.Type(), b.Type());
     }
     switch (a.Type()) {
-    case Type::U16:
-        return Inst<U16>(Opcode::FPMul16, Flags{control}, a, b);
-    case Type::U32:
-        return Inst<U32>(Opcode::FPMul32, Flags{control}, a, b);
-    case Type::U64:
-        return Inst<U64>(Opcode::FPMul64, Flags{control}, a, b);
+    case Type::F16:
+        return Inst<F16>(Opcode::FPMul16, Flags{control}, a, b);
+    case Type::F32:
+        return Inst<F32>(Opcode::FPMul32, Flags{control}, a, b);
+    case Type::F64:
+        return Inst<F64>(Opcode::FPMul64, Flags{control}, a, b);
     default:
         ThrowInvalidType(a.Type());
     }
 }
 
-U16U32U64 IREmitter::FPFma(const U16U32U64& a, const U16U32U64& b, const U16U32U64& c,
+F16F32F64 IREmitter::FPFma(const F16F32F64& a, const F16F32F64& b, const F16F32F64& c,
                            FpControl control) {
     if (a.Type() != b.Type() || a.Type() != c.Type()) {
         throw InvalidArgument("Mismatching types {}, {}, and {}", a.Type(), b.Type(), c.Type());
     }
     switch (a.Type()) {
-    case Type::U16:
-        return Inst<U16>(Opcode::FPFma16, Flags{control}, a, b, c);
-    case Type::U32:
-        return Inst<U32>(Opcode::FPFma32, Flags{control}, a, b, c);
-    case Type::U64:
-        return Inst<U64>(Opcode::FPFma64, Flags{control}, a, b, c);
+    case Type::F16:
+        return Inst<F16>(Opcode::FPFma16, Flags{control}, a, b, c);
+    case Type::F32:
+        return Inst<F32>(Opcode::FPFma32, Flags{control}, a, b, c);
+    case Type::F64:
+        return Inst<F64>(Opcode::FPFma64, Flags{control}, a, b, c);
     default:
         ThrowInvalidType(a.Type());
     }
 }
 
-U16U32U64 IREmitter::FPAbs(const U16U32U64& value) {
+F16F32F64 IREmitter::FPAbs(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPAbs16, value);
+        return Inst<F16>(Opcode::FPAbs16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPAbs32, value);
+        return Inst<F32>(Opcode::FPAbs32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPAbs64, value);
+        return Inst<F64>(Opcode::FPAbs64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPNeg(const U16U32U64& value) {
+F16F32F64 IREmitter::FPNeg(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPNeg16, value);
+        return Inst<F16>(Opcode::FPNeg16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPNeg32, value);
+        return Inst<F32>(Opcode::FPNeg32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPNeg64, value);
+        return Inst<F64>(Opcode::FPNeg64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPAbsNeg(const U16U32U64& value, bool abs, bool neg) {
-    U16U32U64 result{value};
+F16F32F64 IREmitter::FPAbsNeg(const F16F32F64& value, bool abs, bool neg) {
+    F16F32F64 result{value};
     if (abs) {
         result = FPAbs(value);
     }
@@ -383,108 +476,108 @@ U16U32U64 IREmitter::FPAbsNeg(const U16U32U64& value, bool abs, bool neg) {
     return result;
 }
 
-U32 IREmitter::FPCosNotReduced(const U32& value) {
-    return Inst<U32>(Opcode::FPCosNotReduced, value);
+F32 IREmitter::FPCosNotReduced(const F32& value) {
+    return Inst<F32>(Opcode::FPCosNotReduced, value);
 }
 
-U32 IREmitter::FPExp2NotReduced(const U32& value) {
-    return Inst<U32>(Opcode::FPExp2NotReduced, value);
+F32 IREmitter::FPExp2NotReduced(const F32& value) {
+    return Inst<F32>(Opcode::FPExp2NotReduced, value);
 }
 
-U32 IREmitter::FPLog2(const U32& value) {
-    return Inst<U32>(Opcode::FPLog2, value);
+F32 IREmitter::FPLog2(const F32& value) {
+    return Inst<F32>(Opcode::FPLog2, value);
 }
 
-U32U64 IREmitter::FPRecip(const U32U64& value) {
+F32F64 IREmitter::FPRecip(const F32F64& value) {
     switch (value.Type()) {
     case Type::U32:
-        return Inst<U32>(Opcode::FPRecip32, value);
+        return Inst<F32>(Opcode::FPRecip32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPRecip64, value);
+        return Inst<F64>(Opcode::FPRecip64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U32U64 IREmitter::FPRecipSqrt(const U32U64& value) {
+F32F64 IREmitter::FPRecipSqrt(const F32F64& value) {
     switch (value.Type()) {
     case Type::U32:
-        return Inst<U32>(Opcode::FPRecipSqrt32, value);
+        return Inst<F32>(Opcode::FPRecipSqrt32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPRecipSqrt64, value);
+        return Inst<F64>(Opcode::FPRecipSqrt64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U32 IREmitter::FPSinNotReduced(const U32& value) {
-    return Inst<U32>(Opcode::FPSinNotReduced, value);
+F32 IREmitter::FPSinNotReduced(const F32& value) {
+    return Inst<F32>(Opcode::FPSinNotReduced, value);
 }
 
-U32 IREmitter::FPSqrt(const U32& value) {
-    return Inst<U32>(Opcode::FPSqrt, value);
+F32 IREmitter::FPSqrt(const F32& value) {
+    return Inst<F32>(Opcode::FPSqrt, value);
 }
 
-U16U32U64 IREmitter::FPSaturate(const U16U32U64& value) {
+F16F32F64 IREmitter::FPSaturate(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPSaturate16, value);
+        return Inst<F16>(Opcode::FPSaturate16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPSaturate32, value);
+        return Inst<F32>(Opcode::FPSaturate32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPSaturate64, value);
+        return Inst<F64>(Opcode::FPSaturate64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPRoundEven(const U16U32U64& value) {
+F16F32F64 IREmitter::FPRoundEven(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPRoundEven16, value);
+        return Inst<F16>(Opcode::FPRoundEven16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPRoundEven32, value);
+        return Inst<F32>(Opcode::FPRoundEven32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPRoundEven64, value);
+        return Inst<F64>(Opcode::FPRoundEven64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPFloor(const U16U32U64& value) {
+F16F32F64 IREmitter::FPFloor(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPFloor16, value);
+        return Inst<F16>(Opcode::FPFloor16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPFloor32, value);
+        return Inst<F32>(Opcode::FPFloor32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPFloor64, value);
+        return Inst<F64>(Opcode::FPFloor64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPCeil(const U16U32U64& value) {
+F16F32F64 IREmitter::FPCeil(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPCeil16, value);
+        return Inst<F16>(Opcode::FPCeil16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPCeil32, value);
+        return Inst<F32>(Opcode::FPCeil32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPCeil64, value);
+        return Inst<F64>(Opcode::FPCeil64, value);
     default:
         ThrowInvalidType(value.Type());
     }
 }
 
-U16U32U64 IREmitter::FPTrunc(const U16U32U64& value) {
+F16F32F64 IREmitter::FPTrunc(const F16F32F64& value) {
     switch (value.Type()) {
     case Type::U16:
-        return Inst<U16>(Opcode::FPTrunc16, value);
+        return Inst<F16>(Opcode::FPTrunc16, value);
     case Type::U32:
-        return Inst<U32>(Opcode::FPTrunc32, value);
+        return Inst<F32>(Opcode::FPTrunc32, value);
     case Type::U64:
-        return Inst<U64>(Opcode::FPTrunc64, value);
+        return Inst<F64>(Opcode::FPTrunc64, value);
     default:
         ThrowInvalidType(value.Type());
     }
@@ -605,7 +698,7 @@ U1 IREmitter::LogicalNot(const U1& value) {
     return Inst<U1>(Opcode::LogicalNot, value);
 }
 
-U32U64 IREmitter::ConvertFToS(size_t bitsize, const U16U32U64& value) {
+U32U64 IREmitter::ConvertFToS(size_t bitsize, const F16F32F64& value) {
     switch (bitsize) {
     case 16:
         switch (value.Type()) {
@@ -645,7 +738,7 @@ U32U64 IREmitter::ConvertFToS(size_t bitsize, const U16U32U64& value) {
     }
 }
 
-U32U64 IREmitter::ConvertFToU(size_t bitsize, const U16U32U64& value) {
+U32U64 IREmitter::ConvertFToU(size_t bitsize, const F16F32F64& value) {
     switch (bitsize) {
     case 16:
         switch (value.Type()) {
@@ -685,7 +778,7 @@ U32U64 IREmitter::ConvertFToU(size_t bitsize, const U16U32U64& value) {
     }
 }
 
-U32U64 IREmitter::ConvertFToI(size_t bitsize, bool is_signed, const U16U32U64& value) {
+U32U64 IREmitter::ConvertFToI(size_t bitsize, bool is_signed, const F16F32F64& value) {
     if (is_signed) {
         return ConvertFToS(bitsize, value);
     } else {
-- 
cgit v1.2.3-70-g09d2