Skip to content

Commit 2adf7cf

Browse files
authored
Merge pull request #57 from smoogipoo/reduce-mtl-cb-allocs
Use constant storage space for MTL command buffers
2 parents fe61932 + 675f318 commit 2adf7cf

File tree

4 files changed

+188
-31
lines changed

4 files changed

+188
-31
lines changed

src/Veldrid.MetalBindings/MTLCommandBuffer.cs

+17-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace Veldrid.MetalBindings
66
{
77
[StructLayout(LayoutKind.Sequential)]
8-
public struct MTLCommandBuffer
8+
public struct MTLCommandBuffer : IEquatable<MTLCommandBuffer>
99
{
1010
public readonly IntPtr NativePtr;
1111

@@ -29,6 +29,7 @@ public MTLComputeCommandEncoder computeCommandEncoder()
2929

3030
public void addCompletedHandler(MTLCommandBufferHandler block)
3131
=> objc_msgSend(NativePtr, sel_addCompletedHandler, block);
32+
3233
public void addCompletedHandler(IntPtr block)
3334
=> objc_msgSend(NativePtr, sel_addCompletedHandler, block);
3435

@@ -42,5 +43,20 @@ public void addCompletedHandler(IntPtr block)
4243
private static readonly Selector sel_waitUntilCompleted = "waitUntilCompleted";
4344
private static readonly Selector sel_addCompletedHandler = "addCompletedHandler:";
4445
private static readonly Selector sel_status = "status";
46+
47+
public bool Equals(MTLCommandBuffer other)
48+
{
49+
return NativePtr == other.NativePtr;
50+
}
51+
52+
public override bool Equals(object obj)
53+
{
54+
return obj is MTLCommandBuffer other && Equals(other);
55+
}
56+
57+
public override int GetHashCode()
58+
{
59+
return NativePtr.GetHashCode();
60+
}
4561
}
4662
}
+151
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) ppy Pty Ltd <[email protected]>. Licensed under the MIT Licence.
2+
// See the LICENCE file in the repository root for full licence text.
3+
4+
using System.Collections;
5+
using System.Collections.Generic;
6+
using System.Runtime.InteropServices;
7+
using Veldrid.MetalBindings;
8+
9+
namespace Veldrid.MTL
10+
{
11+
internal class CommandBufferUsageList<T>
12+
{
13+
private readonly List<(MTLCommandBuffer buffer, T value)> items = new List<(MTLCommandBuffer buffer, T item)>();
14+
15+
public void Add(MTLCommandBuffer cb, T value)
16+
=> items.Add((cb, value));
17+
18+
public ItemsEnumerator EnumerateItems()
19+
=> new ItemsEnumerator(items);
20+
21+
public RemovalEnumerator EnumerateAndRemove(MTLCommandBuffer cb)
22+
=> new RemovalEnumerator(items, cb);
23+
24+
public bool Contains(MTLCommandBuffer cb)
25+
{
26+
foreach (var (buffer, _) in items)
27+
{
28+
if (buffer.Equals(cb))
29+
return true;
30+
}
31+
32+
return false;
33+
}
34+
35+
public void Clear()
36+
=> items.Clear();
37+
38+
/// <summary>
39+
/// This is a basic enumerator for the list.
40+
/// </summary>
41+
public struct ItemsEnumerator : IEnumerator<T>, IEnumerable
42+
{
43+
private readonly List<(MTLCommandBuffer buffer, T value)> list;
44+
private int index;
45+
46+
public ItemsEnumerator(List<(MTLCommandBuffer buffer, T value)> list)
47+
{
48+
this.list = list;
49+
}
50+
51+
public bool MoveNext()
52+
{
53+
if (index == list.Count)
54+
return false;
55+
56+
Current = list[index].value;
57+
index++;
58+
59+
return true;
60+
}
61+
62+
public void Reset()
63+
{
64+
index = 0;
65+
}
66+
67+
public T Current { get; private set; }
68+
69+
object IEnumerator.Current => Current;
70+
71+
public void Dispose()
72+
{
73+
}
74+
75+
public ItemsEnumerator GetEnumerator() => this;
76+
77+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
78+
}
79+
80+
/// <summary>
81+
/// This is a combined enumerate + remove enumerator for the list.
82+
///
83+
/// It works by duplicating the items that shall be retained to the end of the list
84+
/// and then moving them in-place to the front of the list upon disposal.
85+
///
86+
/// The combined operation has therefore O(n) time complexity.
87+
/// </summary>
88+
public struct RemovalEnumerator : IEnumerator<T>, IEnumerable
89+
{
90+
private readonly List<(MTLCommandBuffer buffer, T value)> list;
91+
private readonly MTLCommandBuffer cb;
92+
private readonly int count;
93+
private int index;
94+
95+
public RemovalEnumerator(List<(MTLCommandBuffer buffer, T value)> list, MTLCommandBuffer cb)
96+
{
97+
this.list = list;
98+
this.cb = cb;
99+
100+
count = list.Count;
101+
list.EnsureCapacity(count * 2);
102+
}
103+
104+
public bool MoveNext()
105+
{
106+
while (true)
107+
{
108+
if (index == count)
109+
return false;
110+
111+
if (list[index].buffer.Equals(cb))
112+
break;
113+
114+
// Track the item to be kept.
115+
list.Add(list[index]);
116+
index++;
117+
}
118+
119+
Current = list[index].value;
120+
index++;
121+
122+
return true;
123+
}
124+
125+
public void Reset()
126+
{
127+
index = 0;
128+
}
129+
130+
public T Current { get; private set; }
131+
132+
object IEnumerator.Current => Current;
133+
134+
public void Dispose()
135+
{
136+
if (list.Count == 0)
137+
return;
138+
139+
int toKeepItemCount = list.Count - count;
140+
var listSpan = CollectionsMarshal.AsSpan(list);
141+
142+
listSpan.Slice(count, toKeepItemCount).CopyTo(listSpan);
143+
list.RemoveRange(toKeepItemCount, list.Count - toKeepItemCount);
144+
}
145+
146+
public RemovalEnumerator GetEnumerator() => this;
147+
148+
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
149+
}
150+
}
151+
}

src/Veldrid/MTL/MTLCommandList.cs

+13-24
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ internal unsafe class MtlCommandList : CommandList
1515
private readonly MtlGraphicsDevice gd;
1616

1717
private readonly List<MtlBuffer> availableStagingBuffers = new List<MtlBuffer>();
18-
private readonly Dictionary<MTLCommandBuffer, List<MtlBuffer>> submittedStagingBuffers = new Dictionary<MTLCommandBuffer, List<MtlBuffer>>();
18+
private readonly CommandBufferUsageList<MtlBuffer> submittedStagingBuffers = new CommandBufferUsageList<MtlBuffer>();
1919
private readonly object submittedCommandsLock = new object();
20-
private readonly Dictionary<MTLCommandBuffer, MtlFence> completionFences = new Dictionary<MTLCommandBuffer, MtlFence>();
20+
private readonly CommandBufferUsageList<MtlFence> completionFences = new CommandBufferUsageList<MtlFence>();
2121

2222
private readonly Dictionary<UIntPtr, DeviceBufferRange> boundVertexBuffers = new Dictionary<UIntPtr, DeviceBufferRange>();
2323
private readonly Dictionary<UIntPtr, DeviceBufferRange> boundFragmentBuffers = new Dictionary<UIntPtr, DeviceBufferRange>();
@@ -85,12 +85,11 @@ public override void Dispose()
8585

8686
lock (submittedStagingBuffers)
8787
{
88-
foreach (var buffer in availableStagingBuffers) buffer.Dispose();
88+
foreach (var buffer in availableStagingBuffers)
89+
buffer.Dispose();
8990

90-
foreach (var kvp in submittedStagingBuffers)
91-
{
92-
foreach (var buffer in kvp.Value) buffer.Dispose();
93-
}
91+
foreach (var buffer in submittedStagingBuffers.EnumerateItems())
92+
buffer.Dispose();
9493

9594
submittedStagingBuffers.Clear();
9695
}
@@ -160,26 +159,20 @@ public void SetCompletionFence(MTLCommandBuffer cb, MtlFence fence)
160159
{
161160
lock (submittedCommandsLock)
162161
{
163-
Debug.Assert(!completionFences.ContainsKey(cb));
164-
completionFences[cb] = fence;
162+
Debug.Assert(!completionFences.Contains(cb));
163+
completionFences.Add(cb, fence);
165164
}
166165
}
167166

168167
public void OnCompleted(MTLCommandBuffer cb)
169168
{
170169
lock (submittedCommandsLock)
171170
{
172-
if (completionFences.TryGetValue(cb, out var completionFence))
173-
{
174-
completionFence.Set();
175-
completionFences.Remove(cb);
176-
}
171+
foreach (var fence in completionFences.EnumerateAndRemove(cb))
172+
fence.Set();
177173

178-
if (submittedStagingBuffers.TryGetValue(cb, out var bufferList))
179-
{
180-
availableStagingBuffers.AddRange(bufferList);
181-
submittedStagingBuffers.Remove(cb);
182-
}
174+
foreach (var buffer in submittedStagingBuffers.EnumerateAndRemove(cb))
175+
availableStagingBuffers.Add(buffer);
183176
}
184177
}
185178

@@ -1267,11 +1260,7 @@ private protected override void UpdateBufferCore(DeviceBuffer buffer, uint buffe
12671260
}
12681261

12691262
lock (submittedCommandsLock)
1270-
{
1271-
if (!submittedStagingBuffers.TryGetValue(cb, out var bufferList)) submittedStagingBuffers[cb] = bufferList = new List<MtlBuffer>();
1272-
1273-
bufferList.Add(staging);
1274-
}
1263+
submittedStagingBuffers.Add(cb, staging);
12751264
}
12761265

12771266
private protected override void GenerateMipmapsCore(Texture texture)

src/Veldrid/MTL/MTLGraphicsDevice.cs

+7-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ private static readonly Dictionary<IntPtr, MtlGraphicsDevice> s_aot_registered_b
4949
private readonly bool[] supportedSampleCounts;
5050

5151
private readonly object submittedCommandsLock = new object();
52-
private readonly Dictionary<MTLCommandBuffer, MtlCommandList> submittedCLs = new Dictionary<MTLCommandBuffer, MtlCommandList>();
52+
private readonly CommandBufferUsageList<MtlCommandList> submittedCLs = new CommandBufferUsageList<MtlCommandList>();
5353

5454
private readonly object resetEventsLock = new object();
5555
private readonly List<ManualResetEvent[]> resetEvents = new List<ManualResetEvent[]>();
@@ -385,11 +385,11 @@ private void OnCommandBufferCompleted(IntPtr block, MTLCommandBuffer cb)
385385
{
386386
lock (submittedCommandsLock)
387387
{
388-
var cl = submittedCLs[cb];
389-
submittedCLs.Remove(cb);
390-
cl.OnCompleted(cb);
388+
foreach (var cl in submittedCLs.EnumerateAndRemove(cb))
389+
cl.OnCompleted(cb);
391390

392-
if (latestSubmittedCb.NativePtr == cb.NativePtr) latestSubmittedCb = default;
391+
if (latestSubmittedCb.NativePtr == cb.NativePtr)
392+
latestSubmittedCb = default;
393393
}
394394

395395
ObjectiveCRuntime.release(cb.NativePtr);
@@ -459,7 +459,8 @@ private protected override void SubmitCommandsCore(CommandList commandList, Fenc
459459

460460
lock (submittedCommandsLock)
461461
{
462-
if (fence != null) mtlCl.SetCompletionFence(mtlCl.CommandBuffer, Util.AssertSubtype<Fence, MtlFence>(fence));
462+
if (fence != null)
463+
mtlCl.SetCompletionFence(mtlCl.CommandBuffer, Util.AssertSubtype<Fence, MtlFence>(fence));
463464

464465
submittedCLs.Add(mtlCl.CommandBuffer, mtlCl);
465466
latestSubmittedCb = mtlCl.Commit();

0 commit comments

Comments
 (0)