From 6c4cc0cd062fbbba5349da1108d3c23cb330ca8a Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Tue, 2 Feb 2021 21:07:00 -0300
Subject: shader: SSA and dominance

---
 .../frontend/maxwell/control_flow.cpp              | 130 ++++++++++++++++++++-
 1 file changed, 124 insertions(+), 6 deletions(-)

(limited to 'src/shader_recompiler/frontend/maxwell/control_flow.cpp')

diff --git a/src/shader_recompiler/frontend/maxwell/control_flow.cpp b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
index fc4dba8269..21ee981371 100644
--- a/src/shader_recompiler/frontend/maxwell/control_flow.cpp
+++ b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
@@ -36,6 +36,7 @@ static std::array<Block, 2> Split(Block&& block, Location pc, BlockId new_id) {
             .cond{true},
             .branch_true{new_id},
             .branch_false{UNREACHABLE_BLOCK_ID},
+            .imm_predecessors{},
         },
         Block{
             .begin{pc},
@@ -46,6 +47,7 @@ static std::array<Block, 2> Split(Block&& block, Location pc, BlockId new_id) {
             .cond{block.cond},
             .branch_true{block.branch_true},
             .branch_false{block.branch_false},
+            .imm_predecessors{},
         },
     };
 }
@@ -108,7 +110,7 @@ static bool HasFlowTest(Opcode opcode) {
     }
 }
 
-static std::string Name(const Block& block) {
+static std::string NameOf(const Block& block) {
     if (block.begin.IsVirtual()) {
         return fmt::format("\"Virtual {}\"", block.id);
     } else {
@@ -154,13 +156,127 @@ bool Block::Contains(Location pc) const noexcept {
 }
 
 Function::Function(Location start_address)
-    : entrypoint{start_address}, labels{Label{
+    : entrypoint{start_address}, labels{{
                                      .address{start_address},
                                      .block_id{0},
                                      .stack{},
                                  }} {}
 
+void Function::BuildBlocksMap() {
+    const size_t num_blocks{NumBlocks()};
+    blocks_map.resize(num_blocks);
+    for (size_t block_index = 0; block_index < num_blocks; ++block_index) {
+        Block& block{blocks_data[block_index]};
+        blocks_map[block.id] = &block;
+    }
+}
+
+void Function::BuildImmediatePredecessors() {
+    for (const Block& block : blocks_data) {
+        if (block.branch_true != UNREACHABLE_BLOCK_ID) {
+            blocks_map[block.branch_true]->imm_predecessors.push_back(block.id);
+        }
+        if (block.branch_false != UNREACHABLE_BLOCK_ID) {
+            blocks_map[block.branch_false]->imm_predecessors.push_back(block.id);
+        }
+    }
+}
+
+void Function::BuildPostOrder() {
+    boost::container::small_vector<BlockId, 0x110> block_stack;
+    post_order_map.resize(NumBlocks());
+
+    Block& first_block{blocks_data[blocks.front()]};
+    first_block.post_order_visited = true;
+    block_stack.push_back(first_block.id);
+
+    const auto visit_branch = [&](BlockId block_id, BlockId branch_id) {
+        if (branch_id == UNREACHABLE_BLOCK_ID) {
+            return false;
+        }
+        if (blocks_map[branch_id]->post_order_visited) {
+            return false;
+        }
+        blocks_map[branch_id]->post_order_visited = true;
+
+        // Calling push_back twice is faster than insert on msvc
+        block_stack.push_back(block_id);
+        block_stack.push_back(branch_id);
+        return true;
+    };
+    while (!block_stack.empty()) {
+        const Block* const block{blocks_map[block_stack.back()]};
+        block_stack.pop_back();
+
+        if (!visit_branch(block->id, block->branch_true) &&
+            !visit_branch(block->id, block->branch_false)) {
+            post_order_map[block->id] = static_cast<u32>(post_order_blocks.size());
+            post_order_blocks.push_back(block->id);
+        }
+    }
+}
+
+void Function::BuildImmediateDominators() {
+    auto transform_block_id{std::views::transform([this](BlockId id) { return blocks_map[id]; })};
+    auto reverse_order_but_first{std::views::reverse | std::views::drop(1) | transform_block_id};
+    auto has_idom{std::views::filter([](Block* block) { return block->imm_dominator; })};
+    auto intersect{[this](Block* finger1, Block* finger2) {
+        while (finger1 != finger2) {
+            while (post_order_map[finger1->id] < post_order_map[finger2->id]) {
+                finger1 = finger1->imm_dominator;
+            }
+            while (post_order_map[finger2->id] < post_order_map[finger1->id]) {
+                finger2 = finger2->imm_dominator;
+            }
+        }
+        return finger1;
+    }};
+    for (Block& block : blocks_data) {
+        block.imm_dominator = nullptr;
+    }
+    Block* const start_block{&blocks_data[blocks.front()]};
+    start_block->imm_dominator = start_block;
+
+    bool changed{true};
+    while (changed) {
+        changed = false;
+        for (Block* const block : post_order_blocks | reverse_order_but_first) {
+            Block* new_idom{};
+            for (Block* predecessor : block->imm_predecessors | transform_block_id | has_idom) {
+                new_idom = new_idom ? intersect(predecessor, new_idom) : predecessor;
+            }
+            changed |= block->imm_dominator != new_idom;
+            block->imm_dominator = new_idom;
+        }
+    }
+}
+
+void Function::BuildDominanceFrontier() {
+    auto transform_block_id{std::views::transform([this](BlockId id) { return blocks_map[id]; })};
+    auto has_enough_predecessors{[](Block& block) { return block.imm_predecessors.size() >= 2; }};
+    for (Block& block : blocks_data | std::views::filter(has_enough_predecessors)) {
+        for (Block* current : block.imm_predecessors | transform_block_id) {
+            while (current != block.imm_dominator) {
+                current->dominance_frontiers.push_back(current->id);
+                current = current->imm_dominator;
+            }
+        }
+    }
+}
+
 CFG::CFG(Environment& env_, Location start_address) : env{env_} {
+    VisitFunctions(start_address);
+
+    for (Function& function : functions) {
+        function.BuildBlocksMap();
+        function.BuildImmediatePredecessors();
+        function.BuildPostOrder();
+        function.BuildImmediateDominators();
+        function.BuildDominanceFrontier();
+    }
+}
+
+void CFG::VisitFunctions(Location start_address) {
     functions.emplace_back(start_address);
     for (FunctionId function_id = 0; function_id < functions.size(); ++function_id) {
         while (!functions[function_id].labels.empty()) {
@@ -202,6 +318,7 @@ void CFG::AnalyzeLabel(FunctionId function_id, Label& label) {
         .cond{true},
         .branch_true{UNREACHABLE_BLOCK_ID},
         .branch_false{UNREACHABLE_BLOCK_ID},
+        .imm_predecessors{},
     };
     // Analyze instructions until it reaches an already visited block or there's a branch
     bool is_branch{false};
@@ -310,7 +427,7 @@ CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Locati
         // Technically CAL pushes into PRET, but that's implicit in the function call for us
         // Insert the function into the list if it doesn't exist
         if (std::ranges::find(functions, cal_pc, &Function::entrypoint) == functions.end()) {
-            functions.push_back(cal_pc);
+            functions.emplace_back(cal_pc);
         }
         // Handle CAL like a regular instruction
         break;
@@ -352,6 +469,7 @@ void CFG::AnalyzeCondInst(Block& block, FunctionId function_id, Location pc,
         .cond{cond},
         .branch_true{conditional_block_id},
         .branch_false{UNREACHABLE_BLOCK_ID},
+        .imm_predecessors{},
     })};
     // Set the end properties of the conditional instruction and give it a new identity
     Block& conditional_block{block};
@@ -465,14 +583,14 @@ std::string CFG::Dot() const {
         dot += fmt::format("\t\tnode [style=filled];\n");
         for (const u32 block_index : function.blocks) {
             const Block& block{function.blocks_data[block_index]};
-            const std::string name{Name(block)};
+            const std::string name{NameOf(block)};
             const auto add_branch = [&](BlockId branch_id, bool add_label) {
                 const auto it{std::ranges::find(function.blocks_data, branch_id, &Block::id)};
                 dot += fmt::format("\t\t{}->", name);
                 if (it == function.blocks_data.end()) {
                     dot += fmt::format("\"Unknown label {}\"", branch_id);
                 } else {
-                    dot += Name(*it);
+                    dot += NameOf(*it);
                 };
                 if (add_label && block.cond != true && block.cond != false) {
                     dot += fmt::format(" [label=\"{}\"]", block.cond);
@@ -520,7 +638,7 @@ std::string CFG::Dot() const {
         if (functions.front().blocks.empty()) {
             dot += "Start;\n";
         } else {
-            dot += fmt::format("\tStart -> {};\n", Name(functions.front().blocks_data.front()));
+            dot += fmt::format("\tStart -> {};\n", NameOf(functions.front().blocks_data.front()));
         }
         dot += fmt::format("\tStart [shape=diamond];\n");
     }
-- 
cgit v1.2.3-70-g09d2