From 8af9297f0972d0aaa8306369c5d04926b886a89e Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 14 Feb 2021 01:24:32 -0300
Subject: shader: Misc fixes

---
 src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp | 113 +++++++++++-----------
 1 file changed, 59 insertions(+), 54 deletions(-)

(limited to 'src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp')

diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index 8ca996e935..7eaf719c4e 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -113,6 +113,7 @@ private:
     IR::Value ReadVariableRecursive(auto variable, IR::Block* block) {
         IR::Value val;
         if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) {
+            // Optimize the common case of one predecessor: no phi needed
             val = ReadVariable(variable, preds.front());
         } else {
             // Break potential cycles with operandless phi
@@ -160,66 +161,70 @@ private:
 
     DefTable current_def;
 };
+
+void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
+    switch (inst.Opcode()) {
+    case IR::Opcode::SetRegister:
+        if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
+            pass.WriteVariable(reg, block, inst.Arg(1));
+        }
+        break;
+    case IR::Opcode::SetPred:
+        if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) {
+            pass.WriteVariable(pred, block, inst.Arg(1));
+        }
+        break;
+    case IR::Opcode::SetGotoVariable:
+        pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1));
+        break;
+    case IR::Opcode::SetZFlag:
+        pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetSFlag:
+        pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetCFlag:
+        pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::SetOFlag:
+        pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0));
+        break;
+    case IR::Opcode::GetRegister:
+        if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
+            inst.ReplaceUsesWith(pass.ReadVariable(reg, block));
+        }
+        break;
+    case IR::Opcode::GetPred:
+        if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) {
+            inst.ReplaceUsesWith(pass.ReadVariable(pred, block));
+        }
+        break;
+    case IR::Opcode::GetGotoVariable:
+        inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block));
+        break;
+    case IR::Opcode::GetZFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block));
+        break;
+    case IR::Opcode::GetSFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block));
+        break;
+    case IR::Opcode::GetCFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block));
+        break;
+    case IR::Opcode::GetOFlag:
+        inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block));
+        break;
+    default:
+        break;
+    }
+}
 } // Anonymous namespace
 
 void SsaRewritePass(IR::Function& function) {
     Pass pass;
     for (IR::Block* const block : function.blocks) {
         for (IR::Inst& inst : block->Instructions()) {
-            switch (inst.Opcode()) {
-            case IR::Opcode::SetRegister:
-                if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
-                    pass.WriteVariable(reg, block, inst.Arg(1));
-                }
-                break;
-            case IR::Opcode::SetPred:
-                if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) {
-                    pass.WriteVariable(pred, block, inst.Arg(1));
-                }
-                break;
-            case IR::Opcode::SetGotoVariable:
-                pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1));
-                break;
-            case IR::Opcode::SetZFlag:
-                pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0));
-                break;
-            case IR::Opcode::SetSFlag:
-                pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0));
-                break;
-            case IR::Opcode::SetCFlag:
-                pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0));
-                break;
-            case IR::Opcode::SetOFlag:
-                pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0));
-                break;
-            case IR::Opcode::GetRegister:
-                if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) {
-                    inst.ReplaceUsesWith(pass.ReadVariable(reg, block));
-                }
-                break;
-            case IR::Opcode::GetPred:
-                if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) {
-                    inst.ReplaceUsesWith(pass.ReadVariable(pred, block));
-                }
-                break;
-            case IR::Opcode::GetGotoVariable:
-                inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block));
-                break;
-            case IR::Opcode::GetZFlag:
-                inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block));
-                break;
-            case IR::Opcode::GetSFlag:
-                inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block));
-                break;
-            case IR::Opcode::GetCFlag:
-                inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block));
-                break;
-            case IR::Opcode::GetOFlag:
-                inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block));
-                break;
-            default:
-                break;
-            }
+            VisitInst(pass, block, inst);
         }
     }
 }
-- 
cgit v1.2.3-70-g09d2