From 71f96fa6366dc6dd306a953bca1b958fb32bc55a Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 14 Mar 2021 03:41:05 -0300
Subject: shader: Implement CAL inlining function calls

---
 .../frontend/maxwell/control_flow.cpp              | 78 +++++++++++-----------
 1 file changed, 38 insertions(+), 40 deletions(-)

(limited to 'src/shader_recompiler/frontend/maxwell/control_flow.cpp')

diff --git a/src/shader_recompiler/frontend/maxwell/control_flow.cpp b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
index d0dc663307..715c0e92d8 100644
--- a/src/shader_recompiler/frontend/maxwell/control_flow.cpp
+++ b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
@@ -31,13 +31,12 @@ struct Compare {
         return lhs.begin < rhs.begin;
     }
 };
-} // Anonymous namespace
 
-static u32 BranchOffset(Location pc, Instruction inst) {
+u32 BranchOffset(Location pc, Instruction inst) {
     return pc.Offset() + inst.branch.Offset() + 8;
 }
 
-static void Split(Block* old_block, Block* new_block, Location pc) {
+void Split(Block* old_block, Block* new_block, Location pc) {
     if (pc <= old_block->begin || pc >= old_block->end) {
         throw InvalidArgument("Invalid address to split={}", pc);
     }
@@ -49,21 +48,19 @@ static void Split(Block* old_block, Block* new_block, Location pc) {
         .cond{old_block->cond},
         .branch_true{old_block->branch_true},
         .branch_false{old_block->branch_false},
-        .ir{nullptr},
     };
     *old_block = Block{
         .begin{old_block->begin},
         .end{pc},
         .end_class{EndClass::Branch},
         .stack{std::move(old_block->stack)},
-        .cond{IR::Condition{true}},
+        .cond{true},
         .branch_true{new_block},
         .branch_false{nullptr},
-        .ir{nullptr},
     };
 }
 
-static Token OpcodeToken(Opcode opcode) {
+Token OpcodeToken(Opcode opcode) {
     switch (opcode) {
     case Opcode::PBK:
     case Opcode::BRK:
@@ -89,7 +86,7 @@ static Token OpcodeToken(Opcode opcode) {
     }
 }
 
-static bool IsAbsoluteJump(Opcode opcode) {
+bool IsAbsoluteJump(Opcode opcode) {
     switch (opcode) {
     case Opcode::JCAL:
     case Opcode::JMP:
@@ -100,7 +97,7 @@ static bool IsAbsoluteJump(Opcode opcode) {
     }
 }
 
-static bool HasFlowTest(Opcode opcode) {
+bool HasFlowTest(Opcode opcode) {
     switch (opcode) {
     case Opcode::BRA:
     case Opcode::BRX:
@@ -121,13 +118,14 @@ static bool HasFlowTest(Opcode opcode) {
     }
 }
 
-static std::string NameOf(const Block& block) {
+std::string NameOf(const Block& block) {
     if (block.begin.IsVirtual()) {
         return fmt::format("\"Virtual {}\"", block.begin);
     } else {
         return fmt::format("\"{}\"", block.begin);
     }
 }
+} // Anonymous namespace
 
 void Stack::Push(Token token, Location target) {
     entries.push_back({
@@ -166,26 +164,24 @@ bool Block::Contains(Location pc) const noexcept {
     return pc >= begin && pc < end;
 }
 
-Function::Function(Location start_address)
+Function::Function(ObjectPool<Block>& block_pool, Location start_address)
     : entrypoint{start_address}, labels{{
                                      .address{start_address},
-                                     .block{nullptr},
+                                     .block{block_pool.Create(Block{
+                                         .begin{start_address},
+                                         .end{start_address},
+                                         .end_class{EndClass::Branch},
+                                         .stack{},
+                                         .cond{true},
+                                         .branch_true{nullptr},
+                                         .branch_false{nullptr},
+                                     })},
                                      .stack{},
                                  }} {}
 
 CFG::CFG(Environment& env_, ObjectPool<Block>& block_pool_, Location start_address)
     : env{env_}, block_pool{block_pool_} {
-    functions.emplace_back(start_address);
-    functions.back().labels.back().block = block_pool.Create(Block{
-        .begin{start_address},
-        .end{start_address},
-        .end_class{EndClass::Branch},
-        .stack{},
-        .cond{IR::Condition{true}},
-        .branch_true{nullptr},
-        .branch_false{nullptr},
-        .ir{nullptr},
-    });
+    functions.emplace_back(block_pool, start_address);
     for (FunctionId function_id = 0; function_id < functions.size(); ++function_id) {
         while (!functions[function_id].labels.empty()) {
             Function& function{functions[function_id]};
@@ -308,11 +304,17 @@ CFG::AnalysisState CFG::AnalyzeInst(Block* block, FunctionId function_id, Locati
         const Location cal_pc{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
         // Technically CAL pushes into PRET, but that's implicit in the function call for us
         // Insert the function into the list if it doesn't exist
-        if (std::ranges::find(functions, cal_pc, &Function::entrypoint) == functions.end()) {
-            functions.emplace_back(cal_pc);
+        const auto it{std::ranges::find(functions, cal_pc, &Function::entrypoint)};
+        const bool exists{it != functions.end()};
+        const FunctionId call_id{exists ? std::distance(functions.begin(), it) : functions.size()};
+        if (!exists) {
+            functions.emplace_back(block_pool, cal_pc);
         }
-        // Handle CAL like a regular instruction
-        break;
+        block->end_class = EndClass::Call;
+        block->function_call = call_id;
+        block->return_block = AddLabel(block, block->stack, pc + 1, function_id);
+        block->end = pc;
+        return AnalysisState::Branch;
     }
     default:
         break;
@@ -348,7 +350,6 @@ void CFG::AnalyzeCondInst(Block* block, FunctionId function_id, Location pc,
         .cond{cond},
         .branch_true{conditional_block},
         .branch_false{nullptr},
-        .ir{nullptr},
     };
     // Save the contents of the visited block in the conditional block
     *conditional_block = std::move(*block);
@@ -401,16 +402,6 @@ void CFG::AnalyzeBRX(Block*, Location, Instruction, bool is_absolute) {
     throw NotImplementedException("{}", is_absolute ? "JMX" : "BRX");
 }
 
-void CFG::AnalyzeCAL(Location pc, Instruction inst, bool is_absolute) {
-    const Location cal_pc{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
-    // Technically CAL pushes into PRET, but that's implicit in the function call for us
-    // Insert the function to the function list if it doesn't exist
-    const auto it{std::ranges::find(functions, cal_pc, &Function::entrypoint)};
-    if (it == functions.end()) {
-        functions.emplace_back(cal_pc);
-    }
-}
-
 CFG::AnalysisState CFG::AnalyzeEXIT(Block* block, FunctionId function_id, Location pc,
                                     Instruction inst) {
     const IR::FlowTest flow_test{inst.branch.flow_test};
@@ -455,10 +446,9 @@ Block* CFG::AddLabel(Block* block, Stack stack, Location pc, FunctionId function
         .end{pc},
         .end_class{EndClass::Branch},
         .stack{stack},
-        .cond{IR::Condition{true}},
+        .cond{true},
         .branch_true{nullptr},
         .branch_false{nullptr},
-        .ir{nullptr},
     })};
     function.labels.push_back(Label{
         .address{pc},
@@ -495,6 +485,14 @@ std::string CFG::Dot() const {
                     add_branch(block.branch_false, false);
                 }
                 break;
+            case EndClass::Call:
+                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
+                dot += fmt::format("\t\tN{}->{};\n", node_uid, NameOf(*block.return_block));
+                dot += fmt::format("\t\tN{} [label=\"Call {}\"][shape=square][style=stripped];\n",
+                                   node_uid, block.function_call);
+                dot += '\n';
+                ++node_uid;
+                break;
             case EndClass::Exit:
                 dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
                 dot += fmt::format("\t\tN{} [label=\"Exit\"][shape=square][style=stripped];\n",
-- 
cgit v1.2.3-70-g09d2