From dbd882ddeb1a1a9233c0085d0b8ccb022db385b2 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sat, 27 Mar 2021 04:59:58 -0300
Subject: shader: Better interpolation and disabled attributes support

---
 .../backend/spirv/emit_context.cpp                 | 29 ++++++++++++++++--
 .../backend/spirv/emit_spirv_context_get_set.cpp   | 29 ++++++++++++------
 src/shader_recompiler/frontend/maxwell/program.cpp | 35 ++++++++++++++++++++++
 .../translate/impl/load_store_attribute.cpp        | 10 +------
 .../ir_opt/collect_shader_info_pass.cpp            |  2 +-
 src/shader_recompiler/profile.h                    |  1 +
 src/shader_recompiler/shader_info.h                | 13 +++++++-
 7 files changed, 96 insertions(+), 23 deletions(-)

(limited to 'src/shader_recompiler')

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 4d5dabcbfd..a8ca33c1db 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -76,6 +76,8 @@ Id GetAttributeType(EmitContext& ctx, AttributeType type) {
         return ctx.TypeVector(ctx.TypeInt(32, true), 4);
     case AttributeType::UnsignedInt:
         return ctx.U32[4];
+    case AttributeType::Disabled:
+        break;
     }
     throw InvalidArgument("Invalid attribute type {}", type);
 }
@@ -305,15 +307,36 @@ void EmitContext::DefineInputs(const Info& info) {
     if (info.loads_front_face) {
         front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing);
     }
-    for (size_t index = 0; index < info.loads_generics.size(); ++index) {
-        if (!info.loads_generics[index]) {
+    for (size_t index = 0; index < info.input_generics.size(); ++index) {
+        const InputVarying generic{info.input_generics[index]};
+        if (!generic.used) {
             continue;
         }
-        const Id type{GetAttributeType(*this, profile.generic_input_types[index])};
+        const AttributeType input_type{profile.generic_input_types[index]};
+        if (input_type == AttributeType::Disabled) {
+            continue;
+        }
+        const Id type{GetAttributeType(*this, input_type)};
         const Id id{DefineInput(*this, type)};
         Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
         Name(id, fmt::format("in_attr{}", index));
         input_generics[index] = id;
+
+        if (stage != Stage::Fragment) {
+            continue;
+        }
+        switch (generic.interpolation) {
+        case Interpolation::Smooth:
+            // Default
+            // Decorate(id, spv::Decoration::Smooth);
+            break;
+        case Interpolation::NoPerspective:
+            Decorate(id, spv::Decoration::NoPerspective);
+            break;
+        case Interpolation::Flat:
+            Decorate(id, spv::Decoration::Flat);
+            break;
+        }
     }
 }
 
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 6fa16eb805..4cbc2aec10 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -10,16 +10,23 @@
 
 namespace Shader::Backend::SPIRV {
 namespace {
-std::tuple<Id, Id, bool> AttrTypes(EmitContext& ctx, u32 index) {
-    const bool is_first_reader{ctx.stage == Stage::VertexB};
+struct AttrInfo {
+    Id pointer;
+    Id id;
+    bool needs_cast;
+};
+
+std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
     const AttributeType type{ctx.profile.generic_input_types.at(index)};
     switch (type) {
     case AttributeType::Float:
-        return {ctx.input_f32, ctx.F32[1], false};
+        return AttrInfo{ctx.input_f32, ctx.F32[1], false};
     case AttributeType::UnsignedInt:
-        return {ctx.input_u32, ctx.U32[1], true};
+        return AttrInfo{ctx.input_u32, ctx.U32[1], true};
     case AttributeType::SignedInt:
-        return {ctx.input_s32, ctx.TypeInt(32, true), true};
+        return AttrInfo{ctx.input_s32, ctx.TypeInt(32, true), true};
+    case AttributeType::Disabled:
+        return std::nullopt;
     }
     throw InvalidArgument("Invalid attribute type {}", type);
 }
@@ -129,11 +136,15 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr) {
     const auto element_id{[&] { return ctx.Constant(ctx.U32[1], element); }};
     if (IR::IsGeneric(attr)) {
         const u32 index{IR::GenericAttributeIndex(attr)};
-        const auto [pointer_type, type, needs_cast]{AttrTypes(ctx, index)};
+        const std::optional<AttrInfo> type{AttrTypes(ctx, index)};
+        if (!type) {
+            // Attribute is disabled
+            return ctx.Constant(ctx.F32[1], 0.0f);
+        }
         const Id generic_id{ctx.input_generics.at(index)};
-        const Id pointer{ctx.OpAccessChain(pointer_type, generic_id, element_id())};
-        const Id value{ctx.OpLoad(type, pointer)};
-        return needs_cast ? ctx.OpBitcast(ctx.F32[1], value) : value;
+        const Id pointer{ctx.OpAccessChain(type->pointer, generic_id, element_id())};
+        const Id value{ctx.OpLoad(type->id, pointer)};
+        return type->needs_cast ? ctx.OpBitcast(ctx.F32[1], value) : value;
     }
     switch (attr) {
     case IR::Attribute::PositionX:
diff --git a/src/shader_recompiler/frontend/maxwell/program.cpp b/src/shader_recompiler/frontend/maxwell/program.cpp
index 6efaf6ee08..a914a91f48 100644
--- a/src/shader_recompiler/frontend/maxwell/program.cpp
+++ b/src/shader_recompiler/frontend/maxwell/program.cpp
@@ -27,6 +27,40 @@ static void RemoveUnreachableBlocks(IR::Program& program) {
     });
 }
 
+static void CollectInterpolationInfo(Environment& env, IR::Program& program) {
+    if (program.stage != Stage::Fragment) {
+        return;
+    }
+    const ProgramHeader& sph{env.SPH()};
+    for (size_t index = 0; index < program.info.input_generics.size(); ++index) {
+        std::optional<PixelImap> imap;
+        for (const PixelImap value : sph.ps.GenericInputMap(static_cast<u32>(index))) {
+            if (value == PixelImap::Unused) {
+                continue;
+            }
+            if (imap && imap != value) {
+                throw NotImplementedException("Per component interpolation");
+            }
+            imap = value;
+        }
+        if (!imap) {
+            continue;
+        }
+        program.info.input_generics[index].interpolation = [&] {
+            switch (*imap) {
+            case PixelImap::Unused:
+            case PixelImap::Perspective:
+                return Interpolation::Smooth;
+            case PixelImap::Constant:
+                return Interpolation::Flat;
+            case PixelImap::ScreenLinear:
+                return Interpolation::NoPerspective;
+            }
+            throw NotImplementedException("Unknown interpolation {}", *imap);
+        }();
+    }
+}
+
 IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Block>& block_pool,
                              Environment& env, Flow::CFG& cfg) {
     IR::Program program;
@@ -51,6 +85,7 @@ IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Blo
     Optimization::IdentityRemovalPass(program);
     Optimization::VerificationPass(program);
     Optimization::CollectShaderInfoPass(program);
+    CollectInterpolationInfo(env, program);
     return program;
 }
 
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
index 516ffec2da..54bc1e34c3 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/load_store_attribute.cpp
@@ -151,16 +151,8 @@ void TranslatorVisitor::IPA(u64 insn) {
             value = ir.FPMul(value, position_w);
         }
     }
-    switch (ipa.interpolation_mode) {
-    case InterpolationMode::Pass:
-        break;
-    case InterpolationMode::Multiply:
+    if (ipa.interpolation_mode == InterpolationMode::Multiply) {
         value = ir.FPMul(value, F(ipa.multiplier));
-        break;
-    case InterpolationMode::Constant:
-        throw NotImplementedException("IPA.CONSTANT");
-    case InterpolationMode::Sc:
-        throw NotImplementedException("IPA.SC");
     }
 
     // Saturated IPAs are generally generated out of clamped varyings.
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index 0ec0d4c019..60be672283 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -28,7 +28,7 @@ void AddConstantBufferDescriptor(Info& info, u32 index, u32 count) {
 
 void GetAttribute(Info& info, IR::Attribute attribute) {
     if (IR::IsGeneric(attribute)) {
-        info.loads_generics.at(IR::GenericAttributeIndex(attribute)) = true;
+        info.input_generics.at(IR::GenericAttributeIndex(attribute)).used = true;
         return;
     }
     switch (attribute) {
diff --git a/src/shader_recompiler/profile.h b/src/shader_recompiler/profile.h
index 41550bfc63..e260477511 100644
--- a/src/shader_recompiler/profile.h
+++ b/src/shader_recompiler/profile.h
@@ -14,6 +14,7 @@ enum class AttributeType : u8 {
     Float,
     SignedInt,
     UnsignedInt,
+    Disabled,
 };
 
 struct Profile {
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 8ab66bb2ab..9111159f36 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -31,6 +31,17 @@ enum class TextureType : u32 {
     ShadowArrayCube,
 };
 
+enum class Interpolation {
+    Smooth,
+    Flat,
+    NoPerspective,
+};
+
+struct InputVarying {
+    Interpolation interpolation{Interpolation::Smooth};
+    bool used{false};
+};
+
 struct TextureDescriptor {
     TextureType type;
     u32 cbuf_index;
@@ -58,7 +69,7 @@ struct Info {
     bool uses_local_invocation_id{};
     bool uses_subgroup_invocation_id{};
 
-    std::array<bool, 32> loads_generics{};
+    std::array<InputVarying, 32> input_generics{};
     bool loads_position{};
     bool loads_instance_id{};
     bool loads_vertex_id{};
-- 
cgit v1.2.3-70-g09d2