diff options
Diffstat (limited to 'externals/sirit/src/stream.h')
-rw-r--r-- | externals/sirit/src/stream.h | 266 |
1 files changed, 266 insertions, 0 deletions
diff --git a/externals/sirit/src/stream.h b/externals/sirit/src/stream.h new file mode 100644 index 0000000000..7029b6c4e7 --- /dev/null +++ b/externals/sirit/src/stream.h @@ -0,0 +1,266 @@ +/* This file is part of the sirit project. + * Copyright (c) 2019 sirit + * This software may be used and distributed according to the terms of the + * 3-Clause BSD License + */ + +#pragma once + +#include <bit> +#include <cassert> +#include <concepts> +#include <cstddef> +#include <functional> +#include <string_view> +#include <unordered_map> +#include <variant> +#include <vector> +#include <utility> + +#ifndef __cpp_lib_bit_cast +#include <cstring> +#endif + +#include <spirv/unified1/spirv.hpp> + +#include "common_types.h" + +namespace Sirit { + +class Declarations; + +struct OpId { + OpId(spv::Op opcode_) : opcode{opcode_} {} + OpId(spv::Op opcode_, Id result_type_) : opcode{opcode_}, result_type{result_type_} { + assert(result_type.value != 0); + } + + spv::Op opcode{}; + Id result_type{}; +}; + +struct EndOp {}; + +inline size_t WordsInString(std::string_view string) { + return string.size() / sizeof(u32) + 1; +} + +inline void InsertStringView(std::vector<u32>& words, size_t& insert_index, + std::string_view string) { + const size_t size = string.size(); + const auto read = [string, size](size_t offset) { + return offset < size ? static_cast<u32>(string[offset]) : 0u; + }; + + for (size_t i = 0; i < size; i += sizeof(u32)) { + words[insert_index++] = read(i) | read(i + 1) << 8 | read(i + 2) << 16 | read(i + 3) << 24; + } + if (size % sizeof(u32) == 0) { + words[insert_index++] = 0; + } +} + +class Stream { + friend Declarations; + +public: + explicit Stream(u32* bound_) : bound{bound_} {} + + void Reserve(size_t num_words) { + if (insert_index + num_words <= words.size()) { + return; + } + words.resize(insert_index + num_words); + } + + std::span<const u32> Words() const noexcept { + return std::span(words.data(), insert_index); + } + + u32 LocalAddress() const noexcept { + return static_cast<u32>(words.size()); + } + + u32 Value(u32 index) const noexcept { + return words[index]; + } + + void SetValue(u32 index, u32 value) noexcept { + words[index] = value; + } + + Stream& operator<<(spv::Op op) { + op_index = insert_index; + words[insert_index++] = static_cast<u32>(op); + return *this; + } + + Stream& operator<<(OpId op) { + op_index = insert_index; + words[insert_index++] = static_cast<u32>(op.opcode); + if (op.result_type.value != 0) { + words[insert_index++] = op.result_type.value; + } + words[insert_index++] = ++*bound; + return *this; + } + + Id operator<<(EndOp) { + const size_t num_words = insert_index - op_index; + words[op_index] |= static_cast<u32>(num_words) << 16; + return Id{*bound}; + } + + Stream& operator<<(u32 value) { + words[insert_index++] = value; + return *this; + } + + Stream& operator<<(s32 value) { + return *this << static_cast<u32>(value); + } + + Stream& operator<<(u64 value) { + return *this << static_cast<u32>(value) << static_cast<u32>(value >> 32); + } + + Stream& operator<<(s64 value) { + return *this << static_cast<u64>(value); + } + + Stream& operator<<(float value) { +#ifdef __cpp_lib_bit_cast + return *this << std::bit_cast<u32>(value); +#else + static_assert(sizeof(float) == sizeof(u32)); + u32 int_value; + std::memcpy(&int_value, &value, sizeof(int_value)); + return *this << int_value; +#endif + } + + Stream& operator<<(double value) { +#ifdef __cpp_lib_bit_cast + return *this << std::bit_cast<u64>(value); +#else + static_assert(sizeof(double) == sizeof(u64)); + u64 int_value; + std::memcpy(&int_value, &value, sizeof(int_value)); + return *this << int_value; +#endif + } + + Stream& operator<<(bool value) { + return *this << static_cast<u32>(value ? 1 : 0); + } + + Stream& operator<<(Id value) { + assert(value.value != 0); + return *this << value.value; + } + + Stream& operator<<(const Literal& literal) { + std::visit([this](auto value) { *this << value; }, literal); + return *this; + } + + Stream& operator<<(std::string_view string) { + InsertStringView(words, insert_index, string); + return *this; + } + + Stream& operator<<(const char* string) { + return *this << std::string_view{string}; + } + + template <typename T> + requires std::is_enum_v<T> Stream& operator<<(T value) { + static_assert(sizeof(T) == sizeof(u32)); + return *this << static_cast<u32>(value); + } + + template <typename T> + Stream& operator<<(std::optional<T> value) { + if (value) { + *this << *value; + } + return *this; + } + + template <typename T> + Stream& operator<<(std::span<const T> values) { + for (const auto& value : values) { + *this << value; + } + return *this; + } + +private: + u32* bound = nullptr; + std::vector<u32> words; + size_t insert_index = 0; + size_t op_index = 0; +}; + +class Declarations { +public: + explicit Declarations(u32* bound) : stream{bound} {} + + void Reserve(size_t num_words) { + return stream.Reserve(num_words); + } + + std::span<const u32> Words() const noexcept { + return stream.Words(); + } + + template <typename T> + Declarations& operator<<(const T& value) { + stream << value; + return *this; + } + + // Declarations without an id don't exist + Declarations& operator<<(spv::Op) = delete; + + Declarations& operator<<(OpId op) { + id_index = op.result_type.value != 0 ? 2 : 1; + stream << op; + return *this; + } + + Id operator<<(EndOp) { + const auto begin = stream.words.data(); + std::vector<u32> declarations(begin + stream.op_index, begin + stream.insert_index); + + // Normalize result id for lookups + const u32 id = std::exchange(declarations[id_index], 0); + + const auto [entry, inserted] = existing_declarations.emplace(declarations, id); + if (inserted) { + return stream << EndOp{}; + } + // If the declaration already exists, undo the operation + stream.insert_index = stream.op_index; + --*stream.bound; + + return Id{entry->second}; + } + +private: + struct HashVector { + size_t operator()(const std::vector<u32>& vector) const noexcept { + size_t hash = std::hash<size_t>{}(vector.size()); + for (const u32 value : vector) { + hash ^= std::hash<u32>{}(value); + } + return hash; + } + }; + + Stream stream; + std::unordered_map<std::vector<u32>, u32, HashVector> existing_declarations; + size_t id_index = 0; +}; + +} // namespace Sirit |