diff options
author | ReinUsesLisp <reinuseslisp@airmail.cc> | 2021-02-14 01:24:32 -0300 |
---|---|---|
committer | ameerj <52414509+ameerj@users.noreply.github.com> | 2021-07-22 21:51:22 -0400 |
commit | 8af9297f0972d0aaa8306369c5d04926b886a89e (patch) | |
tree | 43bb3f50d694b615d2ae821eef84e417166d4890 /src/shader_recompiler/ir_opt | |
parent | 9170200a11715d131645d1ffb92e86e6ef0d7e88 (diff) |
shader: Misc fixes
Diffstat (limited to 'src/shader_recompiler/ir_opt')
-rw-r--r-- | src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 27 | ||||
-rw-r--r-- | src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp | 113 |
2 files changed, 70 insertions, 70 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index 9fba6ac239..cbde65b9b4 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp @@ -32,6 +32,8 @@ template <typename T> return value.U1(); } else if constexpr (std::is_same_v<T, u32>) { return value.U32(); + } else if constexpr (std::is_same_v<T, s32>) { + return static_cast<s32>(value.U32()); } else if constexpr (std::is_same_v<T, f32>) { return value.F32(); } else if constexpr (std::is_same_v<T, u64>) { @@ -39,17 +41,8 @@ template <typename T> } } -template <typename ImmFn> +template <typename T, typename ImmFn> bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { - const auto arg = [](const IR::Value& value) { - if constexpr (std::is_invocable_r_v<bool, ImmFn, bool, bool>) { - return value.U1(); - } else if constexpr (std::is_invocable_r_v<u32, ImmFn, u32, u32>) { - return value.U32(); - } else if constexpr (std::is_invocable_r_v<u64, ImmFn, u64, u64>) { - return value.U64(); - } - }; const IR::Value lhs{inst.Arg(0)}; const IR::Value rhs{inst.Arg(1)}; @@ -57,14 +50,14 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { const bool is_rhs_immediate{rhs.IsImmediate()}; if (is_lhs_immediate && is_rhs_immediate) { - const auto result{imm_fn(arg(lhs), arg(rhs))}; + const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))}; inst.ReplaceUsesWith(IR::Value{result}); return false; } if (is_lhs_immediate && !is_rhs_immediate) { IR::Inst* const rhs_inst{rhs.InstRecursive()}; if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { - const auto combined{imm_fn(arg(lhs), arg(rhs_inst->Arg(1)))}; + const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; inst.SetArg(0, rhs_inst->Arg(0)); inst.SetArg(1, IR::Value{combined}); } else { @@ -76,7 +69,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { if (!is_lhs_immediate && is_rhs_immediate) { const IR::Inst* const lhs_inst{lhs.InstRecursive()}; if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { - const auto combined{imm_fn(arg(rhs), arg(lhs_inst->Arg(1)))}; + const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; inst.SetArg(0, lhs_inst->Arg(0)); inst.SetArg(1, IR::Value{combined}); } @@ -101,7 +94,7 @@ void FoldAdd(IR::Inst& inst) { if (inst.HasAssociatedPseudoOperation()) { return; } - if (!FoldCommutative(inst, [](T a, T b) { return a + b; })) { + if (!FoldCommutative<T>(inst, [](T a, T b) { return a + b; })) { return; } const IR::Value rhs{inst.Arg(1)}; @@ -119,7 +112,7 @@ void FoldSelect(IR::Inst& inst) { } void FoldLogicalAnd(IR::Inst& inst) { - if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) { + if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a && b; })) { return; } const IR::Value rhs{inst.Arg(1)}; @@ -133,7 +126,7 @@ void FoldLogicalAnd(IR::Inst& inst) { } void FoldLogicalOr(IR::Inst& inst) { - if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) { + if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a || b; })) { return; } const IR::Value rhs{inst.Arg(1)}; @@ -226,6 +219,8 @@ void ConstantPropagation(IR::Inst& inst) { return FoldLogicalOr(inst); case IR::Opcode::LogicalNot: return FoldLogicalNot(inst); + case IR::Opcode::SLessThan: + return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); case IR::Opcode::ULessThan: return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); case IR::Opcode::BitFieldUExtract: diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp index 8ca996e935..7eaf719c4e 100644 --- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp +++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp @@ -113,6 +113,7 @@ private: IR::Value ReadVariableRecursive(auto variable, IR::Block* block) { IR::Value val; if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) { + // Optimize the common case of one predecessor: no phi needed val = ReadVariable(variable, preds.front()); } else { // Break potential cycles with operandless phi @@ -160,66 +161,70 @@ private: DefTable current_def; }; + +void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { + switch (inst.Opcode()) { + case IR::Opcode::SetRegister: + if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { + pass.WriteVariable(reg, block, inst.Arg(1)); + } + break; + case IR::Opcode::SetPred: + if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { + pass.WriteVariable(pred, block, inst.Arg(1)); + } + break; + case IR::Opcode::SetGotoVariable: + pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1)); + break; + case IR::Opcode::SetZFlag: + pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); + break; + case IR::Opcode::SetSFlag: + pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0)); + break; + case IR::Opcode::SetCFlag: + pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0)); + break; + case IR::Opcode::SetOFlag: + pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0)); + break; + case IR::Opcode::GetRegister: + if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { + inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); + } + break; + case IR::Opcode::GetPred: + if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { + inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); + } + break; + case IR::Opcode::GetGotoVariable: + inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block)); + break; + case IR::Opcode::GetZFlag: + inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); + break; + case IR::Opcode::GetSFlag: + inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block)); + break; + case IR::Opcode::GetCFlag: + inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block)); + break; + case IR::Opcode::GetOFlag: + inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block)); + break; + default: + break; + } +} } // Anonymous namespace void SsaRewritePass(IR::Function& function) { Pass pass; for (IR::Block* const block : function.blocks) { for (IR::Inst& inst : block->Instructions()) { - switch (inst.Opcode()) { - case IR::Opcode::SetRegister: - if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { - pass.WriteVariable(reg, block, inst.Arg(1)); - } - break; - case IR::Opcode::SetPred: - if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { - pass.WriteVariable(pred, block, inst.Arg(1)); - } - break; - case IR::Opcode::SetGotoVariable: - pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1)); - break; - case IR::Opcode::SetZFlag: - pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); - break; - case IR::Opcode::SetSFlag: - pass.WriteVariable(SignFlagTag{}, block, inst.Arg(0)); - break; - case IR::Opcode::SetCFlag: - pass.WriteVariable(CarryFlagTag{}, block, inst.Arg(0)); - break; - case IR::Opcode::SetOFlag: - pass.WriteVariable(OverflowFlagTag{}, block, inst.Arg(0)); - break; - case IR::Opcode::GetRegister: - if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { - inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); - } - break; - case IR::Opcode::GetPred: - if (const IR::Pred pred{inst.Arg(0).Pred()}; pred != IR::Pred::PT) { - inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); - } - break; - case IR::Opcode::GetGotoVariable: - inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block)); - break; - case IR::Opcode::GetZFlag: - inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); - break; - case IR::Opcode::GetSFlag: - inst.ReplaceUsesWith(pass.ReadVariable(SignFlagTag{}, block)); - break; - case IR::Opcode::GetCFlag: - inst.ReplaceUsesWith(pass.ReadVariable(CarryFlagTag{}, block)); - break; - case IR::Opcode::GetOFlag: - inst.ReplaceUsesWith(pass.ReadVariable(OverflowFlagTag{}, block)); - break; - default: - break; - } + VisitInst(pass, block, inst); } } } |