/* This file is part of the sirit project.
 * Copyright (c) 2019 sirit
 * This software may be used and distributed according to the terms of the
 * 3-Clause BSD License
 */

#pragma once

#include <bit>
#include <cassert>
#include <concepts>
#include <cstddef>
#include <functional>
#include <string_view>
#include <unordered_map>
#include <variant>
#include <vector>
#include <utility>

#ifndef __cpp_lib_bit_cast
#include <cstring>
#endif

#include <spirv/unified1/spirv.hpp>

#include "common_types.h"

namespace Sirit {

class Declarations;

/// Opening token of an instruction that yields a result id.
///
/// Holds the opcode and, optionally, the id of the result type. A result type
/// whose value is zero means the instruction has no result type operand.
struct OpId {
    spv::Op opcode{};
    Id result_type{};

    /// Instruction that has a result id but no result type.
    OpId(spv::Op opcode_) : opcode{opcode_} {}

    /// Instruction that has both a result type and a result id.
    /// The result type must be a valid (non-zero) id.
    OpId(spv::Op opcode_, Id result_type_) : opcode{opcode_}, result_type{result_type_} {
        assert(result_type_.value != 0);
    }
};

/// Tag type streamed to finish the instruction that is currently being written.
struct EndOp {};

/// Number of 32-bit words needed to store `string` followed by its null terminator.
inline size_t WordsInString(std::string_view string) {
    constexpr size_t bytes_per_word = sizeof(u32);
    const size_t full_words = string.size() / bytes_per_word;
    // One more word holds the leftover bytes (if any) together with the terminator
    return full_words + 1;
}

/// Packs `string` into `words` starting at `insert_index`, followed by a null
/// terminator.
///
/// Bytes are stored four per word with the first byte in the least significant
/// position. `insert_index` is advanced past every word written.
///
/// The vector is not grown here: `words` must already hold at least
/// WordsInString(string) elements starting at `insert_index`.
inline void InsertStringView(std::vector<u32>& words, size_t& insert_index,
                             std::string_view string) {
    const size_t size = string.size();
    const auto read = [string, size](size_t offset) -> u32 {
        if (offset >= size) {
            // Past the end of the string: zero padding / terminator
            return 0u;
        }
        // Convert through unsigned char. Where plain char is signed, a byte >= 0x80
        // would otherwise be sign-extended and overwrite the neighbouring bytes of
        // the word once the four reads are OR-ed together.
        return static_cast<u32>(static_cast<unsigned char>(string[offset]));
    };

    for (size_t i = 0; i < size; i += sizeof(u32)) {
        words[insert_index++] = read(i) | read(i + 1) << 8 | read(i + 2) << 16 | read(i + 3) << 24;
    }
    // A partially filled last word already ends in zero bytes. When the length is
    // a multiple of the word size (including empty strings) the terminator needs
    // a word of its own.
    if (size % sizeof(u32) == 0) {
        words[insert_index++] = 0;
    }
}

class Stream {
    friend Declarations;

public:
    explicit Stream(u32* bound_) : bound{bound_} {}

    void Reserve(size_t num_words) {
        if (insert_index + num_words <= words.size()) {
            return;
        }
        words.resize(insert_index + num_words);
    }

    std::span<const u32> Words() const noexcept {
        return std::span(words.data(), insert_index);
    }

    u32 LocalAddress() const noexcept {
        return static_cast<u32>(words.size());
    }

    u32 Value(u32 index) const noexcept {
        return words[index];
    }

    void SetValue(u32 index, u32 value) noexcept {
        words[index] = value;
    }

    Stream& operator<<(spv::Op op) {
        op_index = insert_index;
        words[insert_index++] = static_cast<u32>(op);
        return *this;
    }

    Stream& operator<<(OpId op) {
        op_index = insert_index;
        words[insert_index++] = static_cast<u32>(op.opcode);
        if (op.result_type.value != 0) {
            words[insert_index++] = op.result_type.value;
        }
        words[insert_index++] = ++*bound;
        return *this;
    }

    Id operator<<(EndOp) {
        const size_t num_words = insert_index - op_index;
        words[op_index] |= static_cast<u32>(num_words) << 16;
        return Id{*bound};
    }

    Stream& operator<<(u32 value) {
        words[insert_index++] = value;
        return *this;
    }

    Stream& operator<<(s32 value) {
        return *this << static_cast<u32>(value);
    }

    Stream& operator<<(u64 value) {
        return *this << static_cast<u32>(value) << static_cast<u32>(value >> 32);
    }

    Stream& operator<<(s64 value) {
        return *this << static_cast<u64>(value);
    }

    Stream& operator<<(float value) {
#ifdef __cpp_lib_bit_cast
        return *this << std::bit_cast<u32>(value);
#else
        static_assert(sizeof(float) == sizeof(u32));
        u32 int_value;
        std::memcpy(&int_value, &value, sizeof(int_value));
        return *this << int_value;
#endif
    }

    Stream& operator<<(double value) {
#ifdef __cpp_lib_bit_cast
        return *this << std::bit_cast<u64>(value);
#else
        static_assert(sizeof(double) == sizeof(u64));
        u64 int_value;
        std::memcpy(&int_value, &value, sizeof(int_value));
        return *this << int_value;
#endif
    }

    Stream& operator<<(bool value) {
        return *this << static_cast<u32>(value ? 1 : 0);
    }

    Stream& operator<<(Id value) {
        assert(value.value != 0);
        return *this << value.value;
    }

    Stream& operator<<(const Literal& literal) {
        std::visit([this](auto value) { *this << value; }, literal);
        return *this;
    }

    Stream& operator<<(std::string_view string) {
        InsertStringView(words, insert_index, string);
        return *this;
    }

    Stream& operator<<(const char* string) {
        return *this << std::string_view{string};
    }

    template <typename T>
    requires std::is_enum_v<T> Stream& operator<<(T value) {
        static_assert(sizeof(T) == sizeof(u32));
        return *this << static_cast<u32>(value);
    }

    template <typename T>
    Stream& operator<<(std::optional<T> value) {
        if (value) {
            *this << *value;
        }
        return *this;
    }

    template <typename T>
    Stream& operator<<(std::span<const T> values) {
        for (const auto& value : values) {
            *this << value;
        }
        return *this;
    }

private:
    u32* bound = nullptr;
    std::vector<u32> words;
    size_t insert_index = 0;
    size_t op_index = 0;
};

/// Stream wrapper that deduplicates identical declarations.
///
/// Every instruction written through this class starts with an OpId. When the
/// instruction is finished with EndOp, its words (with the result id blanked
/// out) are looked up among the declarations emitted so far. If an identical
/// one already exists, the freshly written words and the newly allocated id are
/// discarded and the id of the existing declaration is returned instead.
class Declarations {
public:
    /// @param bound Counter used by the underlying stream to allocate result ids.
    explicit Declarations(u32* bound) : stream{bound} {}

    /// Makes sure `num_words` more words can be written to the underlying stream.
    void Reserve(size_t num_words) {
        return stream.Reserve(num_words);
    }

    /// View over the words of all declarations kept so far.
    std::span<const u32> Words() const noexcept {
        return stream.Words();
    }

    /// Forwards any operand to the underlying stream.
    template <typename T>
    Declarations& operator<<(const T& value) {
        stream << value;
        return *this;
    }

    // Declarations without an id don't exist
    Declarations& operator<<(spv::Op) = delete;

    /// Starts a declaration and remembers which word will hold its result id.
    Declarations& operator<<(OpId op) {
        // Word 0 is the opcode; the result id follows the result type when there is one
        id_index = op.result_type.value != 0 ? 2 : 1;
        stream << op;
        return *this;
    }

    /// Finishes the current declaration, returning either its new id or the id
    /// of an identical declaration that was emitted earlier.
    Id operator<<(EndOp) {
        const auto begin = stream.words.data();
        std::vector<u32> declarations(begin + stream.op_index, begin + stream.insert_index);

        // Normalize result id for lookups
        const u32 id = std::exchange(declarations[id_index], 0);

        // The key is not used afterwards, so hand it to the map instead of copying it
        const auto [entry, inserted] =
            existing_declarations.try_emplace(std::move(declarations), id);
        if (inserted) {
            return stream << EndOp{};
        }
        // If the declaration already exists, undo the operation
        stream.insert_index = stream.op_index;
        --*stream.bound;

        return Id{entry->second};
    }

private:
    struct HashVector {
        size_t operator()(const std::vector<u32>& vector) const noexcept {
            size_t hash = std::hash<size_t>{}(vector.size());
            for (const u32 value : vector) {
                // Mix in the running hash (boost::hash_combine style) so the result
                // depends on operand order and repeated words do not cancel out,
                // which a plain XOR fold would allow.
                hash ^= std::hash<u32>{}(value) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
            }
            return hash;
        }
    };

    Stream stream;
    // Maps the words of a declaration (result id zeroed) to the id it was given
    std::unordered_map<std::vector<u32>, u32, HashVector> existing_declarations;
    // Position of the result id inside the declaration that is being written
    size_t id_index = 0;
};

} // namespace Sirit