From 34aba9627a8fad20b3b173180e2f3d679dd32293 Mon Sep 17 00:00:00 2001
From: FernandoS27 <fsahmkow27@gmail.com>
Date: Sat, 27 Mar 2021 22:30:24 +0100
Subject: shader: Implement BRX

---
 .../frontend/maxwell/structured_control_flow.cpp   | 57 ++++++++++++++++++++++
 1 file changed, 57 insertions(+)

(limited to 'src/shader_recompiler/frontend/maxwell/structured_control_flow.cpp')

diff --git a/src/shader_recompiler/frontend/maxwell/structured_control_flow.cpp b/src/shader_recompiler/frontend/maxwell/structured_control_flow.cpp
index 9d46883902..a6e55f61ed 100644
--- a/src/shader_recompiler/frontend/maxwell/structured_control_flow.cpp
+++ b/src/shader_recompiler/frontend/maxwell/structured_control_flow.cpp
@@ -17,6 +17,7 @@
 #include "shader_recompiler/environment.h"
 #include "shader_recompiler/frontend/ir/basic_block.h"
 #include "shader_recompiler/frontend/ir/ir_emitter.h"
+#include "shader_recompiler/frontend/maxwell/decode.h"
 #include "shader_recompiler/frontend/maxwell/structured_control_flow.h"
 #include "shader_recompiler/frontend/maxwell/translate/translate.h"
 #include "shader_recompiler/object_pool.h"
@@ -46,12 +47,15 @@ enum class StatementType {
     Break,
     Return,
     Kill,
+    Unreachable,
     Function,
     Identity,
     Not,
     Or,
     SetVariable,
+    SetIndirectBranchVariable,
     Variable,
+    IndirectBranchCond,
 };
 
 bool HasChildren(StatementType type) {
@@ -72,12 +76,15 @@ struct Loop {};
 struct Break {};
 struct Return {};
 struct Kill {};
+struct Unreachable {};
 struct FunctionTag {};
 struct Identity {};
 struct Not {};
 struct Or {};
 struct SetVariable {};
+struct SetIndirectBranchVariable {};
 struct Variable {};
+struct IndirectBranchCond {};
 
 #ifdef _MSC_VER
 #pragma warning(push)
@@ -96,6 +103,7 @@ struct Statement : ListBaseHook {
         : cond{cond_}, up{up_}, type{StatementType::Break} {}
     Statement(Return) : type{StatementType::Return} {}
     Statement(Kill) : type{StatementType::Kill} {}
+    Statement(Unreachable) : type{StatementType::Unreachable} {}
     Statement(FunctionTag) : children{}, type{StatementType::Function} {}
     Statement(Identity, IR::Condition cond_) : guest_cond{cond_}, type{StatementType::Identity} {}
     Statement(Not, Statement* op_) : op{op_}, type{StatementType::Not} {}
@@ -103,7 +111,12 @@ struct Statement : ListBaseHook {
         : op_a{op_a_}, op_b{op_b_}, type{StatementType::Or} {}
     Statement(SetVariable, u32 id_, Statement* op_, Statement* up_)
         : op{op_}, id{id_}, up{up_}, type{StatementType::SetVariable} {}
+    Statement(SetIndirectBranchVariable, IR::Reg branch_reg_, s32 branch_offset_)
+        : branch_offset{branch_offset_},
+          branch_reg{branch_reg_}, type{StatementType::SetIndirectBranchVariable} {}
     Statement(Variable, u32 id_) : id{id_}, type{StatementType::Variable} {}
+    Statement(IndirectBranchCond, u32 location_)
+        : location{location_}, type{StatementType::IndirectBranchCond} {}
 
     ~Statement() {
         if (HasChildren(type)) {
@@ -118,11 +131,14 @@ struct Statement : ListBaseHook {
         IR::Condition guest_cond;
         Statement* op;
         Statement* op_a;
+        u32 location;
+        s32 branch_offset;
     };
     union {
         Statement* cond;
         Statement* op_b;
         u32 id;
+        IR::Reg branch_reg;
     };
     Statement* up{};
     StatementType type;
@@ -141,6 +157,8 @@ std::string DumpExpr(const Statement* stmt) {
         return fmt::format("{} || {}", DumpExpr(stmt->op_a), DumpExpr(stmt->op_b));
     case StatementType::Variable:
         return fmt::format("goto_L{}", stmt->id);
+    case StatementType::IndirectBranchCond:
+        return fmt::format("(indirect_branch == {:x})", stmt->location);
     default:
         return "<invalid type>";
     }
@@ -182,14 +200,22 @@ std::string DumpTree(const Tree& tree, u32 indentation = 0) {
         case StatementType::Kill:
             ret += fmt::format("{}    kill;\n", indent);
             break;
+        case StatementType::Unreachable:
+            ret += fmt::format("{}    unreachable;\n", indent);
+            break;
         case StatementType::SetVariable:
             ret += fmt::format("{}    goto_L{} = {};\n", indent, stmt->id, DumpExpr(stmt->op));
             break;
+        case StatementType::SetIndirectBranchVariable:
+            ret += fmt::format("{}    indirect_branch = {} + {};\n", indent, stmt->branch_reg,
+                               stmt->branch_offset);
+            break;
         case StatementType::Function:
         case StatementType::Identity:
         case StatementType::Not:
         case StatementType::Or:
         case StatementType::Variable:
+        case StatementType::IndirectBranchCond:
             throw LogicError("Statement can't be printed");
         }
     }
@@ -417,6 +443,17 @@ private:
                 }
                 break;
             }
+            case Flow::EndClass::IndirectBranch:
+                root.insert(ip, *pool.Create(SetIndirectBranchVariable{}, block.branch_reg,
+                                             block.branch_offset));
+                for (Flow::Block* const branch : block.indirect_branches) {
+                    const Node indirect_label{local_labels.at(branch)};
+                    Statement* cond{pool.Create(IndirectBranchCond{}, branch->begin.Offset())};
+                    Statement* goto_stmt{pool.Create(Goto{}, cond, indirect_label, &root_stmt)};
+                    gotos.push_back(root.insert(ip, *goto_stmt));
+                }
+                root.insert(ip, *pool.Create(Unreachable{}));
+                break;
             case Flow::EndClass::Call: {
                 Flow::Function& call{cfg.Functions()[block.function_call]};
                 const Node call_return_label{local_labels.at(block.return_block)};
@@ -623,6 +660,8 @@ IR::Block* TryFindForwardBlock(const Statement& stmt) {
         return ir.LogicalOr(VisitExpr(ir, *stmt.op_a), VisitExpr(ir, *stmt.op_b));
     case StatementType::Variable:
         return ir.GetGotoVariable(stmt.id);
+    case StatementType::IndirectBranchCond:
+        return ir.IEqual(ir.GetIndirectBranchVariable(), ir.Imm32(stmt.location));
     default:
         throw NotImplementedException("Statement type {}", stmt.type);
     }
@@ -670,6 +709,15 @@ private:
                 ir.SetGotoVariable(stmt.id, VisitExpr(ir, *stmt.op));
                 break;
             }
+            case StatementType::SetIndirectBranchVariable: {
+                if (!current_block) {
+                    current_block = MergeBlock(parent, stmt);
+                }
+                IR::IREmitter ir{*current_block};
+                IR::U32 address{ir.IAdd(ir.GetReg(stmt.branch_reg), ir.Imm32(stmt.branch_offset))};
+                ir.SetIndirectBranchVariable(address);
+                break;
+            }
             case StatementType::If: {
                 if (!current_block) {
                     current_block = block_pool.Create(inst_pool);
@@ -756,6 +804,15 @@ private:
                 current_block = demote_block;
                 break;
             }
+            case StatementType::Unreachable: {
+                if (!current_block) {
+                    current_block = block_pool.Create(inst_pool);
+                    block_list.push_back(current_block);
+                }
+                IR::IREmitter{*current_block}.Unreachable();
+                current_block = nullptr;
+                break;
+            }
             default:
                 throw NotImplementedException("Statement type {}", stmt.type);
             }
-- 
cgit v1.2.3-70-g09d2