
Gaussian Splatting on Metal Dev Log


kernel void bitonic_swap_asc(device unsigned int* left, device unsigned int* right, uint index [[thread_position_in_grid]]) {
    unsigned int l = left[index];
    unsigned int r = right[index];
    left[index] = min(l, r);
    right[index] = max(l, r);
}

The logic of which elements get compared is encoded into the command buffer rather than into the kernel. This means the earlier layers will have many more commands (though each operates over less data). I doubt this will be optimal, but I'm interested to see both whether it's correct without any memory fences, and how it performs.
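For reference, the comparator schedule those commands spell out is the standard bitonic network. A CPU sketch (the textbook XOR formulation, not the actual Metal host code) showing which pairs each stage/pass touches:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference bitonic sort for power-of-two n. The (k, j) double loop mirrors
// the layers of commands: k is the size of the runs being merged, j is the
// comparator distance within the current pass.
void bitonic_sort_ref(std::vector<unsigned int>& a) {
    const std::size_t n = a.size();
    for (std::size_t k = 2; k <= n; k *= 2) {
        for (std::size_t j = k / 2; j > 0; j /= 2) {
            // The inner loop is what the GPU threads would do in parallel.
            for (std::size_t i = 0; i < n; ++i) {
                const std::size_t partner = i ^ j;
                if (partner > i) {
                    const bool ascending = (i & k) == 0;  // direction alternates per k-block
                    if ((a[i] > a[partner]) == ascending)
                        std::swap(a[i], a[partner]);
                }
            }
        }
    }
}
```

Each iteration of the two outer loops corresponds to one layer of encoded commands; only the inner comparisons run on the GPU.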

Somehow it worked first-try! 🤯

Performance wise, it’s not good:

Generating 1048576 random integers
Generated 1048576 random integers
std::sort() execution time: 21 ms
sort_radix() execution time: 35 ms
sort_bitonic() execution time: 53 ms
[0ms] - Starting bitonic encoding
[2665ms | Δ2665ms] - Finished bitonic encoding
[37180ms | Δ34515ms] - Bitonic execution completed
bitonic_sort_gpu() execution time: 37180 ms

The previous O(n^2) GPU sort took around 27 seconds for 1 million values, while this method is taking over 37 seconds despite the fact it’s meant to be O(n log^2 n). It’s also weird just how long it’s spending encoding the commands, so I suspect there’s a bug where I’m encoding way too many, but because the extra comparisons are idempotent they’re not impacting the results.
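One possible explanation (an assumption about the encoder, not something verified against the code): if every contiguous comparator block gets its own command — one dispatch per left/right span, as the two-pointer kernel shape suggests — the command count grows like n log n rather than log² n:

```cpp
#include <cstdint>

// Count dispatches assuming one command per contiguous comparator block
// (distance j within stage k). This models a guess at the encoder, not the
// actual implementation.
std::uint64_t count_block_dispatches(std::uint64_t n) {
    std::uint64_t total = 0;
    for (std::uint64_t k = 2; k <= n; k *= 2)
        for (std::uint64_t j = k / 2; j > 0; j /= 2)
            total += n / (2 * j);  // each block spans 2*j elements, one command each
    return total;
}
```

Under that assumption, n = 2^20 gives roughly 19.9 million encoded commands, versus just 20·21/2 = 210 if each pass were a single dispatch over the whole buffer — which would account for a multi-second encoding phase.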


Bitonic sort is ~3x slower than the other algorithms at this data scale:

Generating 1048576 random integers
Generated 1048576 random integers
std::sort() execution time: 29 ms
sort_radix() execution time: 34 ms
sort_bitonic() execution time: 105 ms

Extending to nearly 8x slower than std::sort at a higher data scale:

Generating 16777216 random integers
Generated 16777216 random integers
std::sort() execution time: 296 ms
sort_radix() execution time: 523 ms
sort_bitonic() execution time: 2297 ms

Spent a little bit of time trying out some perf optimisations. Initially, the hot loop looked like this:

void bitonic_split(std::vector<unsigned int> &bitonic_seq, const int start, const int end, bool ascending) {
    int diff = (end - start) / 2;
    for (int i = start; i < start + diff; i++) {
        int j = i + diff;
        unsigned int left = bitonic_seq[i];
        unsigned int right = bitonic_seq[j];
        if (ascending) {
            bitonic_seq[i] = std::min(left, right);
            bitonic_seq[j] = std::max(left, right);
        } else {
            bitonic_seq[i] = std::max(left, right);
            bitonic_seq[j] = std::min(left, right);
        }
    }
}
The first thing I tried out was splitting into two separate functions (bitonic_split_asc and bitonic_split_dec), and using a single if statement for the swapping. I figured that would reduce branching and memory reads, but it had a negligible effect (both ran at about 2.6s for 16777216 uints). I also tried replacing the two if branches with a single more complex if (if ((left < right && !ascending) || (right < left && ascending))) and direct assignment. This was a lot slower than the min/max approach (~3.3s). By far the largest perf gain came from combining the min/max and the separate asc/dec functions:

Generating 16777216 random integers
Generated 16777216 random integers
std::sort() execution time: 348 ms
sort_radix() execution time: 599 ms
sort_bitonic() execution time: 1196 ms

An impressive 2x speedup for that change alone, contrary to expectations. I guess branch-prediction failures are much more expensive than the extra memory writes? And maybe something about the cache hierarchy makes writing/not writing fairly cheap.
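That winning combination presumably looks something like this — a sketch, with bitonic_split_asc assumed to mirror the bitonic_split loop above:

```cpp
#include <algorithm>
#include <vector>

// Ascending-only split with branch-free min/max writes. A sketch of the
// combined optimisation: one function per direction, so the loop body has
// no data-dependent branch at all.
void bitonic_split_asc(std::vector<unsigned int>& seq, const int start, const int end) {
    const int diff = (end - start) / 2;
    for (int i = start; i < start + diff; i++) {
        const unsigned int left = seq[i];
        const unsigned int right = seq[i + diff];
        seq[i] = std::min(left, right);        // unconditional writes: nothing to mispredict
        seq[i + diff] = std::max(left, right);
    }
}
```

With no branches in the body, the compiler is also free to vectorise the loop, which may matter as much as the missing mispredictions.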


Algorithms / Reference

The NVIDIA book chapter mentions two algorithms for sorting. The first is the O(n^2) technique covered in the other blog post, which can be implemented trivially with a single shader (or with two that alternate, but I eventually realised you can just shift the offsets and grid size and re-use the same kernel the whole time.)

kernel void slow_sort(device unsigned int* data, uint index [[thread_position_in_grid]]) {
    uint idx = index*2;
    uint left = data[idx];
    uint right = data[idx+1];

    if (left < right) {
        data[idx] = left;
        data[idx+1] = right;
    } else {
        data[idx] = right;
        data[idx+1] = left;
    }
}
This kernel must be run n times so that any element at the start of the list can swap its way to the end.
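A CPU simulation of that schedule, assuming the grid offset alternates by one element each pass (which makes it odd-even transposition sort):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Run the same pairwise "kernel" n times, alternating the base offset between
// 0 and 1: even passes compare (0,1),(2,3)...; odd passes compare (1,2),(3,4)...
// n alternating passes are enough for any element to bubble across the array.
void slow_sort_sim(std::vector<unsigned int>& data) {
    const std::size_t n = data.size();
    for (std::size_t pass = 0; pass < n; ++pass) {
        const std::size_t offset = pass % 2;  // shift the grid by one element
        for (std::size_t idx = offset; idx + 1 < n; idx += 2) {
            const unsigned int left = data[idx];
            const unsigned int right = data[idx + 1];
            data[idx] = std::min(left, right);
            data[idx + 1] = std::max(left, right);
        }
    }
}
```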

The “odd-even merge sort” is then described as an algorithm that sorts odd and even keys separately, then merges them. The stages are then scaled up in powers of two until the whole array is sorted. Unlike the previous algorithm this needs O(log^2 n) passes, giving an O(n log^2 n) runtime. The formatting for their code is broken - here’s the (GLSL fragment) shader they provide that implements the algorithm (comments are theirs):

uniform vec3 Param1;
uniform vec3 Param2;
uniform sampler2D Data;
#define OwnPos gl_TexCoord[0]
// contents of the uniform data fields
#define TwoStage Param1.x
#define Pass_mod_Stage Param1.y
#define TwoStage_PmS_1 Param1.z
#define Width Param2.x
#define Height Param2.y
#define Pass Param2.z
void main(void)  {
    // get self
    vec4 self = texture2D(Data, OwnPos.xy);
    float i = floor(OwnPos.x * Width) + floor(OwnPos.y * Height) * Width;
    // my position within the range to merge
    float j = floor(mod(i, TwoStage));
    float compare;
    if ( (j < Pass_mod_Stage) || (j > TwoStage_PmS_1) )
        // must copy -> compare with self
        compare = 0.0;
    else if ( mod((j + Pass_mod_Stage) / Pass, 2.0) < 1.0)
        // we are on the left side -> compare with partner on the right
        compare = 1.0;
    else
        // we are on the right side -> compare with partner on the left
        compare = -1.0;

    // get the partner
    float adr = i + compare * Pass;
    vec4 partner = texture2D(Data, vec2(floor(mod(adr, Width)) / Width,floor(adr / Width) / Height));
    // on the left it's a < operation; on the right it's a >= operation
    gl_FragColor = (self.x * compare < partner.x * compare) ? self : partner;
}

This is called with some specific pass loops on the CPU. If you squint, this looks pretty similar to the code above with a few extra steps, so I’m assuming there’s some aspect here about which sections of the texture are passed to the shader for performing the sort. The article itself seems focused on how to implement the algorithm efficiently using fragment and vertex shaders, and seems to have been written before the advent of GPGPU given the amount of time spent dealing with the specific restrictions they have. It is interesting to read about those considerations, and I’d be curious to know how much is still relevant even when writing pure-compute shaders.

Development Log

Elements    CPU        GPU
32          1µs        5773µs
65536       4993µs     9034348µs (over 9 seconds!)
// array size of 65536
std::sort() execution time: 4030 µs
sort_radix() execution time: 7227 µs
[0µs] - Starting encoding pass
[9µs | Δ9µs] - Finished encoding
[9µs | Δ0µs] - Committed command buffer
[303µs | Δ293µs] - Execution completed
slow_sort_gpu() execution time: 303853 µs

// array size of 1048576
std::sort() execution time: 58079 µs
sort_radix() execution time: 96485 µs
[0µs] - Starting encoding pass
[112µs | Δ112µs] - Finished encoding
[112µs | Δ0µs] - Committed command buffer
[26947µs | Δ26834µs] - Execution completed
slow_sort_gpu() execution time: 26948016 µs

Looks like the vast majority of the time is the compute: the 9 ms vs 112 ms encoding difference is a little over 10x for a 16x increase in data scale, but the compute time dwarfs it by an order of magnitude.

Interestingly, it turns out that using an if(a<b) implementation is consistently a few milliseconds faster than an x = max(a,b); y = min(a,b) one (~250ms vs ~275ms), which is the opposite of what I’d have expected. Will need to enable profiling to understand what’s going on here, interesting stuff.



Today: let’s get started on writing a sort in Metal and see how far we can get.

Doubling 16777216 values:

Elements    GPU       CPU
1           9460µs    4437µs
2           8536µs    3672µs
8           7877µs    3148µs
64          9367µs    3747µs

(the CPU value should technically be the same every time, so this really just shows some of the variance in this simple testing script).

Seems like I’ll need to do a little more work to optimise the Metal compute. It’s also possible that, despite this being a very parallelisable problem, CPUs are just efficient enough at it, and it’s memory-bound enough, that the parallel advantage isn’t coming through.

Unlogged Work

Started on this a few months back but forgot to do any writeups.

Decided to do some C++ / Metal integration as that felt like a relatively well-supported, well-trodden path. Ran into some annoyances with the build system as usual, but ended up with a simple enough Makefile for both the Metal shaders and the C++ part.

The two things that make it annoying:

  1. Metal shaders have a separate proprietary compiler and have to be pulled into a bundle in their own way. Annoyingly, there was a weird bug that led to different behaviour with the same shader when I compiled/combined the shaders in one particular way (I forget which).
  2. Showing a window with the results is done via Objective-C (Objective-C++ in this case), meaning two ways to define classes and manage memory.


Gaussian splats take a huge number of values, sort them by distance from camera (I think), and then remove occluded values. One of the reasons this works so well is that GPUs can do sorting very very fast. CUDA has some sophisticated GPU sorting libraries already, but Metal doesn’t seem to - so, writing a sorting algorithm in Metal seemed like a good starting point.

There are two algorithms that commonly get implemented on GPUs - Bitonic Sort and Radix Sort. I spent a while trying to figure out how the bitonic sorting algorithm actually works before switching to a simple Radix and implementing it in raw C++ as a baseline for performance. It took a while to get into the swing of that kind of coding, but eventually wrote a working in-place binary radix sort (phew). It performs pretty well - just 2x slower than std::sort on large random data.

Generating 33554432 random integers
Generated 33554432 random integers
std::sort() execution time: 2194525 microseconds
sort_radix() execution time: 3719207 microseconds