From 09fb41dc63eeda3a82580f119704e691ead9e76a Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Mon, 26 Jul 2021 04:15:23 -0300
Subject: shader: Use TryInstRecursive on XMAD multiply folding

Simplify the logic a bit by folding each IsImmediate check and the following
InstRecursive call into a single TryInstRecursive lookup.
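
For reference, each operand check previously looked roughly like this (a
sketch of the old pattern, not the verbatim code):

    // Old: reject immediates first, then resolve the defining instruction.
    const IR::Value lhs_arg{inst.Arg(0)};
    if (lhs_arg.IsImmediate()) {
        return false;
    }
    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};

With TryInstRecursive, which returns nullptr for immediates, the check and
the lookup collapse into one step:

    // New: nullptr means the value was an immediate, so bail out.
    IR::Inst* const lhs_shl{inst.Arg(0).TryInstRecursive()};
    if (!lhs_shl) {
        return false;
    }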
---
 .../ir_opt/constant_propagation_pass.cpp           | 26 ++++++++++------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 8dd6d6c2c8..08a06da020 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -116,33 +116,31 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
      *
      * This optimization has been proven safe by LLVM and MSVC.
      */
-    const IR::Value lhs_arg{inst.Arg(0)};
-    const IR::Value rhs_arg{inst.Arg(1)};
-    if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+    IR::Inst* const lhs_shl{inst.Arg(0).TryInstRecursive()};
+    IR::Inst* const rhs_mul{inst.Arg(1).TryInstRecursive()};
+    if (!lhs_shl || !rhs_mul) {
         return false;
     }
-    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
     if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 ||
         lhs_shl->Arg(1) != IR::Value{16U}) {
         return false;
     }
-    if (lhs_shl->Arg(0).IsImmediate()) {
+    IR::Inst* const lhs_mul{lhs_shl->Arg(0).TryInstRecursive()};
+    if (!lhs_mul) {
         return false;
     }
-    IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
-    IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
     if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) {
         return false;
     }
-    if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+    const IR::U32 factor_b{lhs_mul->Arg(1)};
+    if (factor_b.Resolve() != rhs_mul->Arg(1).Resolve()) {
         return false;
     }
-    const IR::U32 factor_b{lhs_mul->Arg(1)};
-    if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).TryInstRecursive()};
+    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).TryInstRecursive()};
+    if (!lhs_bfe || !rhs_bfe) {
         return false;
     }
-    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
-    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
     if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) {
         return false;
     }
@@ -155,10 +153,10 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
     if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
         return false;
     }
-    if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+    const IR::U32 factor_a{lhs_bfe->Arg(0)};
+    if (factor_a.Resolve() != rhs_bfe->Arg(0).Resolve()) {
         return false;
     }
-    const IR::U32 factor_a{lhs_bfe->Arg(0)};
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
     return true;
-- 
cgit v1.2.3-70-g09d2