From 2d48a7b4d0666ad16d03a22d85712617a0849046 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sat, 9 Jan 2021 03:30:07 -0300
Subject: shader: Initial recompiler work

---
 .../frontend/maxwell/control_flow.cpp              | 531 +++++++++++++++++++++
 1 file changed, 531 insertions(+)
 create mode 100644 src/shader_recompiler/frontend/maxwell/control_flow.cpp

(limited to 'src/shader_recompiler/frontend/maxwell/control_flow.cpp')

diff --git a/src/shader_recompiler/frontend/maxwell/control_flow.cpp b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
new file mode 100644
index 0000000000..fc4dba8269
--- /dev/null
+++ b/src/shader_recompiler/frontend/maxwell/control_flow.cpp
@@ -0,0 +1,531 @@
+// Copyright 2021 yuzu Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#include <algorithm>
+#include <array>
+#include <optional>
+#include <ranges>
+#include <string>
+#include <utility>
+
+#include <fmt/format.h>
+
+#include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/maxwell/control_flow.h"
+#include "shader_recompiler/frontend/maxwell/decode.h"
+#include "shader_recompiler/frontend/maxwell/location.h"
+
+namespace Shader::Maxwell::Flow {
+
+static u32 BranchOffset(Location pc, Instruction inst) { // Target of a pc-relative branch
+    return pc.Offset() + inst.branch.Offset() + 8; // NOTE(review): +8 presumably skips this instruction - confirm
+}
+
+static std::array<Block, 2> Split(Block&& block, Location pc, BlockId new_id) {
+    if (pc <= block.begin || pc >= block.end) { // Only an address strictly inside can split
+        throw InvalidArgument("Invalid address to split={}", pc);
+    }
+    return { // The two halves of `block`, cut at `pc`
+        Block{ // First half [block.begin, pc): keeps the id and always jumps to the second half
+            .begin{block.begin},
+            .end{pc},
+            .end_class{EndClass::Branch},
+            .id{block.id},
+            .stack{block.stack}, // Copied here; the original is moved into the second half below
+            .cond{true},
+            .branch_true{new_id},
+            .branch_false{UNREACHABLE_BLOCK_ID},
+        },
+        Block{ // Second half [pc, block.end): takes new_id and inherits how `block` ended
+            .begin{pc},
+            .end{block.end},
+            .end_class{block.end_class},
+            .id{new_id},
+            .stack{std::move(block.stack)},
+            .cond{block.cond},
+            .branch_true{block.branch_true},
+            .branch_false{block.branch_false},
+        },
+    };
+}
+
+static Token OpcodeToken(Opcode opcode) { // Stack token that an opcode pushes or pops
+    switch (opcode) {
+    case Opcode::PBK:
+    case Opcode::BRK:
+        return Token::PBK;
+    case Opcode::PCNT:
+    case Opcode::CONT:
+        return Token::PCNT; // Own token, so that BRK never pops a CONT target and vice versa
+    case Opcode::PEXIT:
+    case Opcode::EXIT:
+        return Token::PEXIT;
+    case Opcode::PLONGJMP:
+    case Opcode::LONGJMP:
+        return Token::PLONGJMP;
+    case Opcode::PRET:
+    case Opcode::RET:
+    case Opcode::CAL:
+        return Token::PRET;
+    case Opcode::SSY:
+    case Opcode::SYNC:
+        return Token::SSY;
+    default: // Opcodes that do not interact with the stack are rejected
+        throw InvalidArgument("{}", opcode);
+    }
+}
+
+// Tells whether `opcode` is one of the jumps treated as absolute by this pass.
+// Only JCAL, JMP and JMX qualify; every other opcode yields false.
+static bool IsAbsoluteJump(Opcode opcode) {
+    static constexpr std::array absolute_jumps{
+        Opcode::JCAL,
+        Opcode::JMP,
+        Opcode::JMX,
+    };
+    return std::ranges::find(absolute_jumps, opcode) != absolute_jumps.end();
+}
+
+// Tells whether a control flow opcode carries a flow test.
+// CAL and JCAL do not; opcodes unknown to this function raise InvalidArgument.
+static bool HasFlowTest(Opcode opcode) {
+    // Opcodes reported as having a flow test
+    static constexpr std::array flow_test_opcodes{
+        Opcode::BRA, Opcode::BRX,  Opcode::EXIT,    Opcode::JMP, Opcode::JMX,
+        Opcode::BRK, Opcode::CONT, Opcode::LONGJMP, Opcode::RET, Opcode::SYNC,
+    };
+    // Calls are accepted here but never have a flow test
+    const bool is_call{opcode == Opcode::CAL || opcode == Opcode::JCAL};
+    if (is_call) {
+        return false;
+    }
+    const auto match{std::ranges::find(flow_test_opcodes, opcode)};
+    if (match == flow_test_opcodes.end()) {
+        // Neither a call nor one of the listed opcodes
+        throw InvalidArgument("Invalid branch {}", opcode);
+    }
+    return true;
+}
+
+// Builds the quoted node name of a block for the dot output: virtual blocks are
+// named after their id, every other block after its begin location.
+static std::string Name(const Block& block) {
+    const bool is_virtual{block.begin.IsVirtual()};
+    return is_virtual ? fmt::format("\"Virtual {}\"", block.id)
+                      : fmt::format("\"{}\"", block.begin);
+}
+
+// Pushes an entry that associates `token` with `target`.
+// Entries are appended, so the last element of `entries` is the newest one;
+// Peek and Remove rely on that by searching from the back.
+void Stack::Push(Token token, Location target) {
+    entries.push_back(StackEntry{.token{token}, .target{target}});
+}
+
+// Returns the target of the newest `token` entry plus the stack without it; throws if absent.
+std::pair<Location, Stack> Stack::Pop(Token token) const {
+    if (const std::optional<Location> target{Peek(token)}; target.has_value()) {
+        return {*target, Remove(token)};
+    }
+    throw LogicError("Token could not be found");
+}
+
+std::optional<Location> Stack::Peek(Token token) const { // Target of the newest `token` entry
+    const auto reverse_entries{entries | std::views::reverse};
+    const auto it{std::ranges::find(reverse_entries, token, &StackEntry::token)};
+    return it == reverse_entries.end() ? std::nullopt : std::optional<Location>{it->target};
+}
+
+Stack Stack::Remove(Token token) const { // Copy without the newest `token` entry and later ones
+    const auto reverse_entries{entries | std::views::reverse};
+    const auto it{std::ranges::find(reverse_entries, token, &StackEntry::token)};
+    if (it == reverse_entries.end()) {
+        // Without a match the range copied below would end before entries.begin()
+        throw LogicError("Token could not be found");
+    }
+    Stack result;
+    result.entries.insert(result.entries.end(), entries.begin(), it.base() - 1);
+    return result;
+}
+
+bool Block::Contains(Location pc) const noexcept { // True when pc lies in [begin, end)
+    return pc >= begin && pc < end;
+}
+
+// Starts a function at `start_address` with a single pending label: the entry
+// point itself, using block id 0 and a value-initialized stack.
+// The label list is what the CFG constructor drains to analyze the function.
+Function::Function(Location start_address)
+    : entrypoint{start_address},
+      labels{Label{.address{start_address}, .block_id{0}, .stack{}}} {}
+
+// Analyzes the function at `start_address` and every function found while doing so.
+CFG::CFG(Environment& env_, Location start_address) : env{env_} {
+    functions.emplace_back(start_address);
+    for (FunctionId id = 0; id < functions.size(); ++id) { // `functions` may grow meanwhile
+        while (!functions[id].labels.empty()) {
+            Label pending{functions[id].labels.back()};
+            functions[id].labels.pop_back();
+            AnalyzeLabel(id, pending);
+        }
+    }
+}
+
+void CFG::AnalyzeLabel(FunctionId function_id, Label& label) { // Builds the block at label.address
+    if (InspectVisitedBlocks(function_id, label)) {
+        // Label address has been visited (the containing block has been split there)
+        return;
+    }
+    // Find the first block, in address order, that begins after the label
+    Function* function{&functions[function_id]};
+    Location pc{label.address};
+    const auto next{std::upper_bound(function->blocks.begin(), function->blocks.end(), pc,
+                                     [function](Location pc, u32 block_index) {
+                                         return pc < function->blocks_data[block_index].begin;
+                                     })};
+    const auto next_index{std::distance(function->blocks.begin(), next)}; // Used for the insert
+    const bool is_last{next == function->blocks.end()};
+    Location next_pc;
+    BlockId next_id{UNREACHABLE_BLOCK_ID};
+    if (!is_last) { // Only then is there a following block to stop at and fall into
+        next_pc = function->blocks_data[*next].begin;
+        next_id = function->blocks_data[*next].id;
+    }
+    // The new block is inserted before the next block once it has been analyzed
+    Block block{
+        .begin{pc},
+        .end{pc}, // Empty until the analysis below sets the real end
+        .end_class{EndClass::Branch},
+        .id{label.block_id},
+        .stack{std::move(label.stack)},
+        .cond{true},
+        .branch_true{UNREACHABLE_BLOCK_ID},
+        .branch_false{UNREACHABLE_BLOCK_ID},
+    };
+    // Analyze instructions until it reaches an already visited block or there's a branch
+    bool is_branch{false};
+    while (is_last || pc < next_pc) { // Without a next block only a branch stops the loop
+        is_branch = AnalyzeInst(block, function_id, pc) == AnalysisState::Branch;
+        if (is_branch) {
+            break;
+        }
+        ++pc;
+    }
+    if (!is_branch) {
+        // If the block finished without a branch,
+        // it means that the next instruction is already visited, jump to it
+        block.end = pc;
+        block.cond = true;
+        block.branch_true = next_id;
+        block.branch_false = UNREACHABLE_BLOCK_ID;
+    }
+    // Function's pointer might be invalid (AnalyzeInst can append to `functions`), resolve it again
+    function = &functions[function_id];
+    const u32 new_block_index = static_cast<u32>(function->blocks_data.size());
+    function->blocks.insert(function->blocks.begin() + next_index, new_block_index);
+    function->blocks_data.push_back(std::move(block));
+}
+
+bool CFG::InspectVisitedBlocks(FunctionId function_id, const Label& label) { // True if visited
+    const Location pc{label.address};
+    Function& function{functions[function_id]};
+    const auto it{std::ranges::find_if(function.blocks, [&function, pc](u32 block_index) {
+        return function.blocks_data[block_index].Contains(pc);
+    })};
+    if (it == function.blocks.end()) {
+        // Address has not been visited
+        return false;
+    }
+    Block& block{function.blocks_data[*it]};
+    if (block.begin == pc) { // A label must never point at the start of a visited block
+        throw LogicError("Dangling branch");
+    }
+    const u32 first_index{*it}; // The first half reuses the slot of the block being split
+    const u32 second_index{static_cast<u32>(function.blocks_data.size())}; // Appended below
+    const std::array new_indices{first_index, second_index};
+    std::array split_blocks{Split(std::move(block), pc, label.block_id)};
+    function.blocks_data[*it] = std::move(split_blocks[0]);
+    function.blocks_data.push_back(std::move(split_blocks[1]));
+    function.blocks.insert(function.blocks.erase(it), new_indices.begin(), new_indices.end());
+    return true;
+}
+
+// Analyzes the instruction at `pc` and updates `block` when that instruction ends it.
+// Returns Branch when the block is finished and Continue when analysis should move on
+// to the next instruction.
+CFG::AnalysisState CFG::AnalyzeInst(Block& block, FunctionId function_id, Location pc) {
+    const Instruction inst{env.ReadInstruction(pc.Offset())};
+    const Opcode opcode{Decode(inst.raw)};
+    switch (opcode) {
+    case Opcode::BRA:
+    case Opcode::BRX:
+    case Opcode::JMP:
+    case Opcode::JMX:
+    case Opcode::RET:
+        if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) {
+            return AnalysisState::Continue;
+        }
+        switch (opcode) {
+        case Opcode::BRA:
+        case Opcode::JMP:
+            AnalyzeBRA(block, function_id, pc, inst, IsAbsoluteJump(opcode));
+            break;
+        case Opcode::BRX:
+        case Opcode::JMX:
+            AnalyzeBRX(block, pc, inst, IsAbsoluteJump(opcode));
+            break;
+        case Opcode::RET:
+            block.end_class = EndClass::Return;
+            break;
+        default:
+            break;
+        }
+        block.end = pc;
+        return AnalysisState::Branch;
+    case Opcode::BRK:
+    case Opcode::CONT:
+    case Opcode::LONGJMP:
+    case Opcode::SYNC: {
+        if (!AnalyzeBranch(block, function_id, pc, inst, opcode)) {
+            return AnalysisState::Continue;
+        }
+        // These jump to the target recorded on the stack for their token
+        const auto [stack_pc, new_stack]{block.stack.Pop(OpcodeToken(opcode))};
+        block.branch_true = AddLabel(block, new_stack, stack_pc, function_id);
+        block.end = pc;
+        return AnalysisState::Branch;
+    }
+    case Opcode::PBK:
+    case Opcode::PCNT:
+    case Opcode::PEXIT:
+    case Opcode::PLONGJMP:
+    case Opcode::SSY:
+        // Remember the target for the instruction that will pop this token
+        block.stack.Push(OpcodeToken(opcode), BranchOffset(pc, inst));
+        return AnalysisState::Continue;
+    case Opcode::EXIT:
+        return AnalyzeEXIT(block, function_id, pc, inst);
+    case Opcode::PRET:
+        throw NotImplementedException("PRET flow analysis");
+    case Opcode::CAL:
+    case Opcode::JCAL:
+        // Register the callee through the shared helper instead of repeating its logic,
+        // then handle the call like a regular instruction
+        AnalyzeCAL(pc, inst, IsAbsoluteJump(opcode));
+        break;
+    default:
+        break;
+    }
+    const Predicate pred{inst.Pred()};
+    if (pred == Predicate{true} || pred == Predicate{false}) {
+        return AnalysisState::Continue;
+    }
+    // A regular instruction with any other predicate ends the block as a conditional
+    const IR::Condition cond{static_cast<IR::Pred>(pred.index), pred.negated};
+    AnalyzeCondInst(block, function_id, pc, EndClass::Branch, cond);
+    return AnalysisState::Branch;
+}
+
+void CFG::AnalyzeCondInst(Block& block, FunctionId function_id, Location pc,
+                          EndClass insn_end_class, IR::Condition cond) {
+    if (block.begin != pc) {
+        // If the block doesn't start in the conditional instruction
+        // mark it as a label to visit it later
+        block.end = pc;
+        block.cond = true;
+        block.branch_true = AddLabel(block, block.stack, pc, function_id);
+        block.branch_false = UNREACHABLE_BLOCK_ID;
+        return;
+    }
+    // Impersonate the visited block with a virtual block
+    // Jump from this virtual to the real conditional instruction and the next instruction
+    Function& function{functions[function_id]};
+    const BlockId conditional_block_id{++function.current_block_id};
+    function.blocks.push_back(static_cast<u32>(function.blocks_data.size())); // Index of the virtual block
+    Block& virtual_block{function.blocks_data.emplace_back(Block{
+        .begin{}, // Virtual block
+        .end{},
+        .end_class{EndClass::Branch},
+        .id{block.id}, // Impersonating
+        .stack{block.stack},
+        .cond{cond},
+        .branch_true{conditional_block_id},
+        .branch_false{UNREACHABLE_BLOCK_ID}, // Replaced below once the endif label exists
+    })};
+    // Set the end properties of the conditional instruction and give it a new identity
+    Block& conditional_block{block};
+    conditional_block.end = pc;
+    conditional_block.end_class = insn_end_class;
+    conditional_block.id = conditional_block_id;
+    // Add a label to the instruction after the conditional instruction
+    const BlockId endif_block_id{AddLabel(conditional_block, block.stack, pc + 1, function_id)};
+    // Branch to the next instruction from the virtual block
+    virtual_block.branch_false = endif_block_id;
+    // And branch to it from the conditional instruction if it is a branch
+    if (insn_end_class == EndClass::Branch) {
+        conditional_block.cond = true;
+        conditional_block.branch_true = endif_block_id;
+        conditional_block.branch_false = UNREACHABLE_BLOCK_ID;
+    }
+}
+
+// Sets the condition of a branch ending `block`; returns false if the branch is never taken.
+bool CFG::AnalyzeBranch(Block& block, FunctionId function_id, Location pc, Instruction inst,
+                        Opcode opcode) {
+    if (inst.branch.is_cbuf) {
+        throw NotImplementedException("Branch with constant buffer offset");
+    }
+    const Predicate predicate{inst.Pred()};
+    if (predicate == Predicate{false}) {
+        return false;
+    }
+    const IR::FlowTest test{HasFlowTest(opcode) ? inst.branch.flow_test.Value() : IR::FlowTest::T};
+    if (predicate == Predicate{true} && test == IR::FlowTest::T) {
+        block.cond = true;
+        return true;
+    }
+    block.cond = IR::Condition(test, static_cast<IR::Pred>(predicate.index), predicate.negated);
+    block.branch_false = AddLabel(block, block.stack, pc + 1, function_id);
+    return true;
+}
+
+void CFG::AnalyzeBRA(Block& block, FunctionId function_id, Location pc, Instruction inst,
+                     bool is_absolute) { // Points the taken edge of `block` at the branch target
+    const Location bra_pc{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
+    block.branch_true = AddLabel(block, block.stack, bra_pc, function_id);
+}
+
+void CFG::AnalyzeBRX(Block&, Location, Instruction, bool is_absolute) { // Not supported yet
+    throw NotImplementedException("{}", is_absolute ? "JMX" : "BRX"); // Names the rejected opcode
+}
+
+// Registers the callee as a function unless it is known (the PRET push stays implicit).
+void CFG::AnalyzeCAL(Location pc, Instruction inst, bool is_absolute) {
+    const Location callee{is_absolute ? inst.branch.Absolute() : BranchOffset(pc, inst)};
+    const auto known{std::ranges::find(functions, callee, &Function::entrypoint)};
+    if (known != functions.end()) {
+        return;
+    }
+    functions.emplace_back(callee);
+}
+
+CFG::AnalysisState CFG::AnalyzeEXIT(Block& block, FunctionId function_id, Location pc,
+                                    Instruction inst) { // Decides how an EXIT ends `block`
+    const IR::FlowTest flow_test{inst.branch.flow_test};
+    const Predicate pred{inst.Pred()};
+    if (pred == Predicate{false} || flow_test == IR::FlowTest::F) {
+        return AnalysisState::Continue; // EXIT will never be taken
+    }
+    if (pred != Predicate{true} || flow_test != IR::FlowTest::T) {
+        if (block.stack.Peek(Token::PEXIT).has_value()) {
+            throw NotImplementedException("Conditional EXIT with PEXIT token");
+        }
+        const IR::Condition cond{flow_test, static_cast<IR::Pred>(pred.index), pred.negated};
+        AnalyzeCondInst(block, function_id, pc, EndClass::Exit, cond);
+        return AnalysisState::Branch;
+    }
+    // An unconditional EXIT always ends the block here, also when a PEXIT token redirects it
+    block.end = pc;
+    if (const std::optional<Location> exit_pc{block.stack.Peek(Token::PEXIT)}) {
+        const Stack popped_stack{block.stack.Remove(Token::PEXIT)};
+        block.cond = true;
+        block.branch_true = AddLabel(block, popped_stack, *exit_pc, function_id);
+        block.branch_false = UNREACHABLE_BLOCK_ID;
+        return AnalysisState::Branch;
+    }
+    block.end_class = EndClass::Exit;
+    return AnalysisState::Branch;
+}
+
+BlockId CFG::AddLabel(const Block& block, Stack stack, Location pc, FunctionId function_id) {
+    Function& function{functions[function_id]}; // Labels and block ids are per function
+    if (block.begin == pc) {
+        return block.id; // The block jumps to itself
+    }
+    const auto target{std::ranges::find(function.blocks_data, pc, &Block::begin)};
+    if (target != function.blocks_data.end()) {
+        return target->id; // A block starting at pc has already been analyzed
+    }
+    const auto pending{std::ranges::find(function.labels, pc, &Label::address)};
+    if (pending != function.labels.end()) {
+        return pending->block_id; // Already queued: a second label would duplicate the block
+    }
+    const BlockId block_id{++function.current_block_id};
+    function.labels.push_back(Label{.address{pc}, .block_id{block_id}, .stack{std::move(stack)}});
+    return block_id;
+}
+
+std::string CFG::Dot() const { // Renders the graph as Graphviz dot text
+    int node_uid{0}; // Numbers the terminal nodes (N0, N1, ...) across all functions
+
+    std::string dot{"digraph shader {\n"};
+    for (const Function& function : functions) { // One cluster subgraph per function
+        dot += fmt::format("\tsubgraph cluster_{} {{\n", function.entrypoint);
+        dot += fmt::format("\t\tnode [style=filled];\n");
+        for (const u32 block_index : function.blocks) {
+            const Block& block{function.blocks_data[block_index]};
+            const std::string name{Name(block)};
+            const auto add_branch = [&](BlockId branch_id, bool add_label) { // Emits one edge
+                const auto it{std::ranges::find(function.blocks_data, branch_id, &Block::id)};
+                dot += fmt::format("\t\t{}->", name);
+                if (it == function.blocks_data.end()) {
+                    dot += fmt::format("\"Unknown label {}\"", branch_id);
+                } else {
+                    dot += Name(*it);
+                };
+                if (add_label && block.cond != true && block.cond != false) { // Non-constant only
+                    dot += fmt::format(" [label=\"{}\"]", block.cond);
+                }
+                dot += '\n';
+            };
+            dot += fmt::format("\t\t{};\n", name);
+            switch (block.end_class) {
+            case EndClass::Branch:
+                if (block.cond != false) { // The true edge exists unless the condition is false
+                    add_branch(block.branch_true, true);
+                }
+                if (block.cond != true) { // The false edge exists unless the condition is true
+                    add_branch(block.branch_false, false);
+                }
+                break;
+            case EndClass::Exit:
+                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
+                dot += fmt::format("\t\tN{} [label=\"Exit\"][shape=square][style=stripped];\n",
+                                   node_uid); // NOTE(review): Graphviz spells this style "striped" - confirm intent
+                ++node_uid;
+                break;
+            case EndClass::Return:
+                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
+                dot += fmt::format("\t\tN{} [label=\"Return\"][shape=square][style=stripped];\n",
+                                   node_uid);
+                ++node_uid;
+                break;
+            case EndClass::Unreachable:
+                dot += fmt::format("\t\t{}->N{};\n", name, node_uid);
+                dot += fmt::format(
+                    "\t\tN{} [label=\"Unreachable\"][shape=square][style=stripped];\n", node_uid);
+                ++node_uid;
+                break;
+            }
+        }
+        if (function.entrypoint == 8) { // NOTE(review): 8 is presumably the main entry offset - confirm
+            dot += fmt::format("\t\tlabel = \"main\";\n");
+        } else {
+            dot += fmt::format("\t\tlabel = \"Function {}\";\n", function.entrypoint);
+        }
+        dot += "\t}\n";
+    }
+    if (!functions.empty()) {
+        if (functions.front().blocks.empty()) {
+            dot += "Start;\n";
+        } else { // Connect Start to the first stored block of the first function
+            dot += fmt::format("\tStart -> {};\n", Name(functions.front().blocks_data.front()));
+        }
+        dot += fmt::format("\tStart [shape=diamond];\n");
+    }
+    dot += "}\n";
+    return dot;
+}
+
+} // namespace Shader::Maxwell::Flow
-- 
cgit v1.2.3-70-g09d2