From 8750b90a7f5e76cdff991a137ec8c2eed0db00dd Mon Sep 17 00:00:00 2001
From: gdkchan <gab.dark.100@gmail.com>
Date: Wed, 30 Nov 2022 18:06:40 -0300
Subject: Ensure that vertex attribute buffer index is valid on GPU (#3942)

* Ensure that vertex attribute buffer index is valid on GPU

* Remove vertex buffer validation code from OpenGL

* Remove some fields that are no longer necessary
---
 Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs | 27 ++++++++++++++-
 Ryujinx.Graphics.OpenGL/VertexArray.cs             | 40 ----------------------
 2 files changed, 26 insertions(+), 41 deletions(-)

diff --git a/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs b/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
index d51077dc..8da5ea5e 100644
--- a/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
+++ b/Ryujinx.Graphics.Gpu/Engine/Threed/StateUpdater.cs
@@ -37,6 +37,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         private bool _vsUsesDrawParameters;
         private bool _vtgWritesRtLayer;
         private byte _vsClipDistancesWritten;
+        private uint _vbEnableMask;
 
         private bool _prevDrawIndexed;
         private bool _prevDrawIndirect;
@@ -76,6 +77,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                     nameof(ThreedClassState.VertexBufferState),
                     nameof(ThreedClassState.VertexBufferEndAddress)),
 
+                // Must be done after vertex buffer updates.
                 new StateUpdateCallbackEntry(UpdateVertexAttribState, nameof(ThreedClassState.VertexAttribState)),
 
                 new StateUpdateCallbackEntry(UpdateBlendState,
@@ -852,12 +854,23 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
         /// </summary>
         private void UpdateVertexAttribState()
         {
+            uint vbEnableMask = _vbEnableMask;
+
             Span<VertexAttribDescriptor> vertexAttribs = stackalloc VertexAttribDescriptor[Constants.TotalVertexAttribs];
 
             for (int index = 0; index < Constants.TotalVertexAttribs; index++)
             {
                 var vertexAttrib = _state.State.VertexAttribState[index];
 
+                int bufferIndex = vertexAttrib.UnpackBufferIndex();
+
+                if ((vbEnableMask & (1u << bufferIndex)) == 0)
+                {
+                    // Using a vertex buffer that doesn't exist is invalid, so let's use a dummy attribute for those cases.
+                    vertexAttribs[index] = new VertexAttribDescriptor(0, 0, true, Format.R32G32B32A32Float);
+                    continue;
+                }
+
                 if (!FormatTable.TryGetAttribFormat(vertexAttrib.UnpackFormat(), out Format format))
                 {
                     Logger.Debug?.Print(LogClass.Gpu, $"Invalid attribute format 0x{vertexAttrib.UnpackFormat():X}.");
@@ -866,7 +879,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 }
 
                 vertexAttribs[index] = new VertexAttribDescriptor(
-                    vertexAttrib.UnpackBufferIndex(),
+                    bufferIndex,
                     vertexAttrib.UnpackOffset(),
                     vertexAttrib.UnpackIsConstant(),
                     format);
@@ -954,6 +967,7 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
             bool drawIndexed = _drawState.DrawIndexed;
             bool drawIndirect = _drawState.DrawIndirect;
+            uint vbEnableMask = 0;
 
             for (int index = 0; index < Constants.TotalVertexBuffers; index++)
             {
@@ -971,6 +985,11 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
 
                 ulong address = vertexBuffer.Address.Pack();
 
+                if (_channel.MemoryManager.IsMapped(address))
+                {
+                    vbEnableMask |= 1u << index;
+                }
+
                 int stride = vertexBuffer.UnpackStride();
 
                 bool instanced = _state.State.VertexBufferInstanced[index];
@@ -1017,6 +1036,12 @@ namespace Ryujinx.Graphics.Gpu.Engine.Threed
                 _pipeline.VertexBuffers[index] = new BufferPipelineDescriptor(_channel.MemoryManager.IsMapped(address), stride, divisor);
                 _channel.BufferManager.SetVertexBuffer(index, address, size, stride, divisor);
             }
+
+            if (_vbEnableMask != vbEnableMask)
+            {
+                _vbEnableMask = vbEnableMask;
+                UpdateVertexAttribState();
+            }
         }
 
         /// <summary>
diff --git a/Ryujinx.Graphics.OpenGL/VertexArray.cs b/Ryujinx.Graphics.OpenGL/VertexArray.cs
index d466199d..7d22033e 100644
--- a/Ryujinx.Graphics.OpenGL/VertexArray.cs
+++ b/Ryujinx.Graphics.OpenGL/VertexArray.cs
@@ -10,13 +10,9 @@ namespace Ryujinx.Graphics.OpenGL
     {
         public int Handle { get; private set; }
 
-        private bool _needsAttribsUpdate;
-
         private readonly VertexAttribDescriptor[] _vertexAttribs;
         private readonly VertexBufferDescriptor[] _vertexBuffers;
 
-        private int _vertexAttribsCount;
-        private int _vertexBuffersCount;
         private int _minVertexCount;
 
         private uint _vertexAttribsInUse;
@@ -76,9 +72,7 @@ namespace Ryujinx.Graphics.OpenGL
                 _vertexBuffers[bindingIndex] = vb;
             }
 
-            _vertexBuffersCount = bindingIndex;
             _minVertexCount = minVertexCount;
-            _needsAttribsUpdate = true;
         }
 
         public void SetVertexAttributes(ReadOnlySpan<VertexAttribDescriptor> vertexAttribs)
@@ -131,8 +125,6 @@ namespace Ryujinx.Graphics.OpenGL
                 _vertexAttribs[index] = attrib;
             }
 
-            _vertexAttribsCount = index;
-
             for (; index < Constants.MaxVertexAttribs; index++)
             {
                 DisableVertexAttrib(index);
@@ -160,13 +152,11 @@ namespace Ryujinx.Graphics.OpenGL
         public void PreDraw(int vertexCount)
         {
             LimitVertexBuffers(vertexCount);
-            Validate();
         }
 
         public void PreDrawVbUnbounded()
         {
             UnlimitVertexBuffers();
-            Validate();
         }
 
         public void LimitVertexBuffers(int vertexCount)
@@ -252,36 +242,6 @@ namespace Ryujinx.Graphics.OpenGL
             _vertexBuffersLimited = 0;
         }
 
-        public void Validate()
-        {
-            for (int attribIndex = 0; attribIndex < _vertexAttribsCount; attribIndex++)
-            {
-                VertexAttribDescriptor attrib = _vertexAttribs[attribIndex];
-
-                if (!attrib.IsZero)
-                {
-                    if ((uint)attrib.BufferIndex >= _vertexBuffersCount)
-                    {
-                        DisableVertexAttrib(attribIndex);
-                        continue;
-                    }
-
-                    if (_vertexBuffers[attrib.BufferIndex].Buffer.Handle == BufferHandle.Null)
-                    {
-                        DisableVertexAttrib(attribIndex);
-                        continue;
-                    }
-
-                    if (_needsAttribsUpdate)
-                    {
-                        EnableVertexAttrib(attribIndex);
-                    }
-                }
-            }
-
-            _needsAttribsUpdate = false;
-        }
-
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private void EnableVertexAttrib(int index)
         {
-- 
cgit v1.2.3-70-g09d2