WebGPU Rendering: Part 11 Prefix Sum
Introduction
I have been reading through this online book on WebGPU. In this series of articles, I will be going through the book and implementing the lessons in a more structured TypeScript class approach, and eventually we will build three types of WebGPU renderers: Gaussian splatting, ray tracing, and rasterization.
In this article we will talk about the compute shader, a new addition in WebGPU that was not present in WebGL. Compute shaders can be used for general-purpose parallel computation, as opposed to the rendering work that the render pipeline we have been using is designed for.
The following link is the commit in my GitHub repo that matches the code we will go over.
Prefix Sum
A prefix sum is a precomputed array where each element at index i contains the sum of the array elements from the start up to i. This allows you to quickly calculate the sum of any subarray in constant time by subtracting two values from the prefix sum array. It’s especially useful for optimizing repeated range sum queries.
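As a quick illustration, here is a minimal TypeScript sketch of that range-query idea (the rangeSum helper is hypothetical, and the prefix array is the exclusive scan of [3,4,1,5] built in the example below):
// Exclusive prefix sums of [3, 4, 1, 5]: prefix[k] holds the sum of the first k elements.
const prefix = [0, 3, 7, 8, 13];
// The sum of elements in the half-open range [i, j) is one subtraction, i.e. O(1) per query.
function rangeSum(p: number[], i: number, j: number): number {
  return p[j] - p[i];
}
console.log(rangeSum(prefix, 1, 3)); // 4 + 1 = 5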
In a sequential setting we would compute the prefix sum array by just iterating over the array of numbers and keeping track of a cumulative sum.
Example:
For [3,4,1,5]
- The first value is 0,
[0]
- Then we add 3 (the first element),
[0,3]
- Then we add the next element, 4,
[0, 3, 7]
- and so on…
[0, 3, 7, 8, 13]
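For reference, here is a minimal TypeScript version of that sequential loop (prefixSumSequential is just an illustrative name); it produces the exclusive scan shown above:
// Sequential exclusive prefix sum: O(n) work, a single pass over the input.
function prefixSumSequential(input: number[]): number[] {
  const result = new Array<number>(input.length + 1);
  result[0] = 0;
  for (let i = 0; i < input.length; i++) {
    result[i + 1] = result[i] + input[i]; // running cumulative sum
  }
  return result;
}
console.log(prefixSumSequential([3, 4, 1, 5])); // [0, 3, 7, 8, 13]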
The algorithm is very simple in the sequential world, but when we cannot simply loop over the array, as is the case in parallel computation, we will require multiple GPU compute passes to generate the resulting prefix sum array.
The sequential algorithm has O(n) time complexity; with parallelism, we can get an O(log n) number of parallel passes by splitting the work into a tree-like structure and running each tree level in parallel.
Example:
For [3,4,1,5]
We work at our “leaf” nodes first, combining each element with its neighbour: this level’s operations are 5+1, 1+4, and 4+3, and each result is stored in the right-hand element of the pair
[3,7,5,6]
Then we double the offset, simulating a move up one level of this binary tree. The operations at this level are 6+7 and 5+3
[3,7,8,13]
We will need log n of these parallel passes to simulate this tree-sum approach.
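To make the level-by-level idea concrete, here is a CPU-side TypeScript simulation (a sketch only; prefixSumTree is an illustrative name). Each outer iteration stands in for one parallel pass, and the inner loop stands in for the threads that would run simultaneously on the GPU. The WGSL shaders below use the work-efficient upsweep/downsweep variant instead, but the number of passes is still logarithmic.
// Simulates log2(n) parallel passes of an inclusive scan.
// Each pass adds the element `offset` positions to the left.
function prefixSumTree(input: number[]): number[] {
  let current = [...input];
  for (let offset = 1; offset < current.length; offset *= 2) {
    const next = [...current];
    for (let i = offset; i < current.length; i++) {
      next[i] = current[i] + current[i - offset]; // these would run in parallel on the GPU
    }
    current = next;
  }
  return current;
}
console.log(prefixSumTree([3, 4, 1, 5])); // [3, 7, 8, 13]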
Shaders
To perform prefix sum on the GPU we will need three separate shaders.
The prefix sum shader lays the groundwork: it computes a local prefix sum for each 512-element chunk of the input in workgroup shared memory, padding its indices to avoid memory bank conflicts, and leaves each chunk's total for the subsequent shaders to combine.
The scan sum shader will take the local prefix sums and combine them together to form a global one.
Finally, the add sum shader will add each chunk's offset onto the partial results in our output array buffer.
Prefix Sum
First, we will look at some input variables and constants in our first shader.
@binding(0) @group(0) var<storage, read> input: array<f32>;
@binding(1) @group(0) var<storage, read_write> output: array<f32>;
@binding(2) @group(0) var<storage, read_write> sums: array<f32>;
const n: u32 = 512;
var<workgroup> temp: array<f32, 532>; // 512 elements plus padding from bank_conflict_free_idx (max padded index is 526, so anything >= 527 works)
const bank_size:u32 = 32;
fn bank_conflict_free_idx(idx: u32) -> u32 {
var chunk_id:u32 = idx / bank_size;
return idx + chunk_id;
}
- input: the input array of f32 values, providing the raw data for the prefix sum.
- output: the output array, where the partially computed prefix sums from this shader will be written.
- sums: a shared buffer that stores the total sum of each workgroup, allowing subsequent shaders to combine them.
- The function bank_conflict_free_idx(idx) adjusts array indices to include padding so that access to workgroup shared memory stays efficient (see the sketch after this list).
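To see what the padding does, here is a small TypeScript reproduction of bank_conflict_free_idx (a sketch for illustration): every block of 32 consecutive indices is shifted by one extra slot, so threads accessing elements a multiple of 32 apart no longer collide on the same shared-memory bank. It also suggests where the temp size comes from: with n = 512 the largest padded index is 526, so 527 entries would suffice and 532 appears to be simply a safe over-allocation.
const BANK_SIZE = 32;
// Mirrors the WGSL bank_conflict_free_idx: insert one padding slot per 32 elements.
function bankConflictFreeIdx(idx: number): number {
  const chunkId = Math.floor(idx / BANK_SIZE);
  return idx + chunkId;
}
console.log(bankConflictFreeIdx(31));  // 31  (first chunk, no padding yet)
console.log(bankConflictFreeIdx(32));  // 33  (shifted past one padding slot)
console.log(bankConflictFreeIdx(64));  // 66
console.log(bankConflictFreeIdx(511)); // 526 -> temp needs at least 527 entries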
Next, we will initialize our values in the shader. Each workgroup handles an array of up to 512 elements with 256 threads, so each thread will need to manage two input values.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>, @builtin(local_invocation_id) LocalInvocationId: vec3<u32>, @builtin(workgroup_id) WorkgroupID: vec3<u32>) {
var thread_id:u32 = LocalInvocationId.x;
var global_thread_id: u32 = GlobalInvocationID.x;
if (thread_id < (n >>1)) {
temp[bank_conflict_free_idx(2* thread_id)] = input[2*global_thread_id];
temp[bank_conflict_free_idx(2* thread_id + 1)] = input[2*global_thread_id + 1];
}
workgroupBarrier();
... MORE CODE
}
workgroupBarrier() ensures all threads within a workgroup have reached the same point before proceeding, so the shared-memory writes above are visible to every thread before the next phase runs.
Then, we will perform the upsweep phase. This phase computes intermediate sums in a tree structure. The offset doubles in each iteration, and pairs of elements are added together.
var offset:u32 = 1;
for (var d:u32 = n >> 1; d > 0; d >>= 1) {
if (thread_id < d) {
var ai: u32 = offset * (2 * thread_id + 1) - 1;
var bi: u32 = offset * (2 * thread_id + 2) - 1;
temp[bank_conflict_free_idx(bi)] += temp[bank_conflict_free_idx(ai)];
}
offset *= 2;
workgroupBarrier(); // make this level's partial sums visible before the next iteration
}
workgroupBarrier();
This performs a reduction reminiscent of merge sort: the sum of each pair in the binary-tree view of the array is stored in the right sibling's slot of the temp array.
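Here is a CPU-side TypeScript sketch of the upsweep on a single four-element workgroup (bank-conflict padding and barriers omitted for clarity; the upsweep name is just for illustration):
// Upsweep (reduce) phase on one workgroup, simulated sequentially.
// After the loop, temp[n - 1] holds the total sum of the workgroup.
function upsweep(temp: number[]): void {
  const n = temp.length;
  let offset = 1;
  for (let d = n >> 1; d > 0; d >>= 1) {
    for (let threadId = 0; threadId < d; threadId++) { // these run in parallel on the GPU
      const ai = offset * (2 * threadId + 1) - 1;
      const bi = offset * (2 * threadId + 2) - 1;
      temp[bi] += temp[ai];
    }
    offset *= 2;
  }
}
const upsweepTemp = [3, 4, 1, 5];
upsweep(upsweepTemp);
console.log(upsweepTemp); // [3, 7, 1, 13] -- the last element is the workgroup total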
The last element in the workgroup (temp[n - 1]) now contains the total sum for that workgroup. This value is saved to the sums buffer for later processing by the subsequent shaders, and it is then reset to 0.0 to prepare for the downsweep phase.
if (thread_id == 0) {
sums[WorkgroupID.x] = temp[bank_conflict_free_idx(n - 1)];
temp[bank_conflict_free_idx(n - 1)] = 0.0;
}
workgroupBarrier();
Then, we will proceed to the downsweep phase. This phase propagates the intermediate sums back down the tree, finalizing the prefix sum for the local workgroup.
for (var d: u32 = 1; d < n; d *= 2) {
offset >>= 1;
if (thread_id < d) {
var ai: u32 = offset * (2 * thread_id + 1) - 1;
var bi: u32 = offset * (2 * thread_id + 2) - 1;
var t: f32 = temp[bank_conflict_free_idx(ai)];
temp[bank_conflict_free_idx(ai)] = temp[bank_conflict_free_idx(bi)];
temp[bank_conflict_free_idx(bi)] += t;
}
workgroupBarrier();
}
As we sweep down, the left child takes the value of its right sibling, and the right sibling becomes its own old value plus the old value of the left.
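Continuing the CPU-side sketch (again with padding and barriers omitted, and an illustrative downsweep name), the downsweep turns the reduced tree into an exclusive prefix sum once temp[n - 1] has been cleared:
// Downsweep phase: assumes upsweep has already run and temp[n - 1] was reset to 0.
function downsweep(temp: number[]): void {
  const n = temp.length;
  let offset = n;
  for (let d = 1; d < n; d *= 2) {
    offset >>= 1;
    for (let threadId = 0; threadId < d; threadId++) { // these run in parallel on the GPU
      const ai = offset * (2 * threadId + 1) - 1;
      const bi = offset * (2 * threadId + 2) - 1;
      const t = temp[ai];
      temp[ai] = temp[bi]; // left child takes the right sibling's value
      temp[bi] += t;       // right sibling accumulates the old left value
    }
  }
}
const downsweepTemp = [3, 7, 1, 13];
downsweepTemp[downsweepTemp.length - 1] = 0; // the total (13) was saved to the sums buffer first
downsweep(downsweepTemp);
console.log(downsweepTemp); // [0, 3, 7, 8] -- exclusive prefix sum of [3, 4, 1, 5]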
Finally, we will store our intermediate prefix sum values to our output buffer.
if (thread_id < (n>>1)) {
output[2*global_thread_id] = temp[bank_conflict_free_idx(2*thread_id)];
output[2*global_thread_id + 1] = temp[bank_conflict_free_idx(2*thread_id + 1)];
}
Scan Sum
This shader does essentially the same thing as the first one, except it expects the entire sums array to fit in a single workgroup (i.e. at most 512 elements, since 256 threads each handle two values).
@binding(0) @group(0) var<storage, read> input: array<f32>;
@binding(1) @group(0) var<storage, read_write> output: array<f32>;
@binding(2) @group(0) var<uniform> n: u32;
const bank_size: u32 = 32;
var<workgroup> temp: array<f32, 532>; // 512 elements plus padding from bank_conflict_free_idx (max padded index is 526, so anything >= 527 works)
fn bank_conflict_free_idx(idx: u32) -> u32 {
var chunk_id: u32 = idx / bank_size;
return idx + chunk_id;
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>, @builtin(local_invocation_id) LocalInvocationId: vec3<u32>, @builtin(workgroup_id) WorkgroupID: vec3<u32>) {
var thread_id: u32 = LocalInvocationId.x;
var global_thread_id: u32 = GlobalInvocationID.x;
if (thread_id < (n>>1)) {
temp[bank_conflict_free_idx(2*thread_id)] = input[2*global_thread_id];
temp[bank_conflict_free_idx(2*thread_id + 1)] = input[2*global_thread_id + 1];
}
workgroupBarrier();
var offset: u32 = 1;
for (var d:u32 = n>>1; d > 0; d >>= 1) {
if (thread_id < d) {
var ai: u32 = offset * (2 * thread_id + 1) - 1;
var bi: u32 = offset * (2 * thread_id + 2) - 1;
temp[bank_conflict_free_idx(bi)] += temp[bank_conflict_free_idx(ai)];
}
offset *= 2;
workgroupBarrier();
}
if (thread_id == 0) {
temp[bank_conflict_free_idx(n - 1)] = 0.0;
}
workgroupBarrier();
for (var d: u32 = 1; d < n; d *= 2) {
offset >>= 1;
if (thread_id < d) {
var ai: u32 = offset * (2 * thread_id + 1) - 1;
var bi: u32 = offset * (2 * thread_id + 2) - 1;
var t: f32 = temp[bank_conflict_free_idx(ai)];
temp[bank_conflict_free_idx(ai)] = temp[bank_conflict_free_idx(bi)];
temp[bank_conflict_free_idx(bi)] += t;
}
workgroupBarrier();
}
if (thread_id < (n>>1)) {
output[2*global_thread_id] = temp[bank_conflict_free_idx(2*thread_id)];
output[2*global_thread_id + 1] = temp[bank_conflict_free_idx(2*thread_id + 1)];
}
}
The first shader performed a parallel prefix sum for each chunk of the input array, writing the partial results to an output buffer and the final sum of each chunk to a separate sums buffer. The second shader processes the sums buffer to compute a global prefix sum across all chunks, producing offsets for each chunk. Finally, these offsets are added to the partial results from the first shader, ensuring a globally consistent prefix sum across the entire input array.
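To see how the three passes compose numerically, here is a CPU-side TypeScript model (a sketch only, using a toy chunk size of 4 instead of 512; exclusiveScan and prefixSumThreePass are illustrative names):
// CPU model of the three GPU passes, with a toy chunk size of 4 instead of 512.
const CHUNK = 4;
// Exclusive scan of one chunk, plus its total (what the prefix sum shader does per workgroup).
function exclusiveScan(chunk: number[]): { scanned: number[]; total: number } {
  const scanned = new Array<number>(chunk.length);
  let running = 0;
  for (let i = 0; i < chunk.length; i++) {
    scanned[i] = running;
    running += chunk[i];
  }
  return { scanned, total: running };
}
function prefixSumThreePass(input: number[]): number[] {
  // Pass 1: per-chunk exclusive scans, plus one total per chunk (the sums buffer).
  const output: number[] = [];
  const sums: number[] = [];
  for (let c = 0; c < input.length; c += CHUNK) {
    const { scanned, total } = exclusiveScan(input.slice(c, c + CHUNK));
    output.push(...scanned);
    sums.push(total);
  }
  // Pass 2: exclusive scan of the chunk totals gives each chunk's global offset.
  const offsets = exclusiveScan(sums).scanned;
  // Pass 3: add each chunk's offset back onto its partial results.
  return output.map((value, i) => value + offsets[Math.floor(i / CHUNK)]);
}
console.log(prefixSumThreePass([3, 4, 1, 5, 2, 2, 2, 2])); // [0, 3, 7, 8, 13, 15, 17, 19]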
Add Sum
When the first two passes are done, we add each chunk's offset from the scanned sums array to every element of that chunk, producing the complete output array.
@binding(0) @group(0) var<storage, read_write> output: array<f32>;
@binding(1) @group(0) var<storage, read> sums: array<f32>;
const n:u32 = 512;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>, @builtin(local_invocation_id) LocalInvocationId: vec3<u32>, @builtin(workgroup_id) WorkgroupID: vec3<u32>) {
var thread_id: u32 = LocalInvocationId.x;
var global_thread_id: u32 = GlobalInvocationID.x;
if (thread_id < (n>>1)) {
output[2*global_thread_id] += sums[WorkgroupID.x];
output[2*global_thread_id + 1] += sums[WorkgroupID.x];
}
}
WebGPUCompute
To help us run our shaders, we will create a WebGPUComputeContext class to handle and encapsulate our GPU communication.
class WebGPUComputeContext {
private static _instance: WebGPUComputeContext | null = null;
private _device: GPUDevice;
public static async create() {
if (WebGPUComputeContext._instance) {
return { instance: WebGPUComputeContext._instance };
}
// make sure gpu is supported
if (!navigator.gpu) {
return { error: "WebGPU not supported" };
}
//grab the adapter
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
return { error: "Failed to get WebGPU adapter" };
}
//create the device (should be done immediately after adapter in case adapter is lost)
const device = await adapter.requestDevice();
if (!device) {
return { error: "Failed to get WebGPU device" };
}
WebGPUComputeContext._instance = new WebGPUComputeContext(device);
return { instance: WebGPUComputeContext._instance };
}
private constructor(device: GPUDevice) {
this._device = device;
}
private _createShaderModule(source: string) {
const shaderModule = this._device.createShaderModule({ code: source });
return shaderModule;
}
public createGPUBuffer(data: Float32Array | Uint16Array | Uint32Array, usage: GPUBufferUsageFlags): GPUBuffer {
const bufferDesc: GPUBufferDescriptor = {
size: data.byteLength,
usage: usage,
mappedAtCreation: true
}
const buffer = this._device.createBuffer(bufferDesc);
if (data instanceof Float32Array) {
const writeArray = new Float32Array(buffer.getMappedRange());
writeArray.set(data);
} else if (data instanceof Uint16Array) {
const writeArray = new Uint16Array(buffer.getMappedRange());
writeArray.set(data);
} else if (data instanceof Uint32Array) {
const writeArray = new Uint32Array(buffer.getMappedRange());
writeArray.set(data);
}
buffer.unmap();
return buffer;
}
}
We create a set of familiar helpers to construct our GPU device and upload buffers, which we will use when running prefix sum. A lot of this code should look very familiar from the WebGPUContext class in our rendering tutorials.
Prefix Sum Function
We will encapsulate running prefix sum in a single function.
public async prefix_sum(input: Float32Array): Promise<Float32Array> {
const pass1ShaderModule = this._createShaderModule(prefixSum);
const pass2ShaderModule = this._createShaderModule(scanSum);
const pass3ShaderModule = this._createShaderModule(addSum);
const pass1UniformBindGroupLayout = this._device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: 'read-only-storage' }
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: "storage" }
},
{
binding: 2,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: "storage" }
}
]
});
const pass2UniformBindGroupLayout = this._device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: 'read-only-storage' }
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: "storage" }
},
{
binding: 2,
visibility: GPUShaderStage.COMPUTE,
buffer: {}
}
]
});
const pass3UniformBindGroupLayout = this._device.createBindGroupLayout({
entries: [
{
binding: 0,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: "storage" }
},
{
binding: 1,
visibility: GPUShaderStage.COMPUTE,
buffer: { type: "read-only-storage" }
}
]
});
const arraySize = input.length;
const chunkCount = Math.ceil(arraySize / 512);
// get nearest power of 2 for chunkCount
let powerOf2 = 1;
while (powerOf2 < chunkCount) {
powerOf2 *= 2;
}
const inputArrayBuffer = this.createGPUBuffer(new Float32Array(input), GPUBufferUsage.STORAGE);
const outputArrayBuffer = this.createGPUBuffer(new Float32Array(input), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
const readOutputArrayBuffer = this.createGPUBuffer(new Float32Array(input), GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);
const sumArrayBuffer = this.createGPUBuffer(new Float32Array(powerOf2), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
const outputSumArrayBuffer = this.createGPUBuffer(new Float32Array(powerOf2), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC);
const readSumArrayBuffer = this.createGPUBuffer(new Float32Array(powerOf2), GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);
const sumSizeBuffer = this.createGPUBuffer(new Uint32Array([powerOf2]), GPUBufferUsage.UNIFORM);
const pass1UniformBindGroup = this._device.createBindGroup({
layout: pass1UniformBindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: inputArrayBuffer
}
},
{
binding: 1,
resource: {
buffer: outputArrayBuffer
}
},
{
binding: 2,
resource: {
buffer: sumArrayBuffer
}
}
]
});
const pass2UniformBindGroup = this._device.createBindGroup({
layout: pass2UniformBindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: sumArrayBuffer
}
},
{
binding: 1,
resource: {
buffer: outputSumArrayBuffer
}
},
{
binding: 2,
resource: {
buffer: sumSizeBuffer
}
}
]
});
const pass3UniformBindGroup = this._device.createBindGroup({
layout: pass3UniformBindGroupLayout,
entries: [
{
binding: 0,
resource: {
buffer: outputArrayBuffer
}
},
{
binding: 1,
resource: {
buffer: outputSumArrayBuffer
}
}
]
});
const pass1Pipeline = this._device.createComputePipeline({
layout: this._device.createPipelineLayout({
bindGroupLayouts: [pass1UniformBindGroupLayout]
}),
compute: {
module: pass1ShaderModule,
entryPoint: 'main'
}
});
const pass2Pipeline = this._device.createComputePipeline({
layout: this._device.createPipelineLayout({
bindGroupLayouts: [pass2UniformBindGroupLayout]
}),
compute: {
module: pass2ShaderModule,
entryPoint: 'main'
}
});
const pass3Pipeline = this._device.createComputePipeline({
layout: this._device.createPipelineLayout({
bindGroupLayouts: [pass3UniformBindGroupLayout]
}),
compute: {
module: pass3ShaderModule,
entryPoint: 'main'
}
});
const computePassDescriptor = {};
const commandEncoder = this._device.createCommandEncoder();
const passEncoder1 = commandEncoder.beginComputePass(computePassDescriptor);
passEncoder1.setPipeline(pass1Pipeline);
passEncoder1.setBindGroup(0, pass1UniformBindGroup);
passEncoder1.dispatchWorkgroups(chunkCount);
passEncoder1.end();
const passEncoder2 = commandEncoder.beginComputePass(computePassDescriptor);
passEncoder2.setPipeline(pass2Pipeline);
passEncoder2.setBindGroup(0, pass2UniformBindGroup);
passEncoder2.dispatchWorkgroups(1);
passEncoder2.end();
const passEncoder3 = commandEncoder.beginComputePass(computePassDescriptor);
passEncoder3.setPipeline(pass3Pipeline);
passEncoder3.setBindGroup(0, pass3UniformBindGroup);
passEncoder3.dispatchWorkgroups(chunkCount);
passEncoder3.end();
commandEncoder.copyBufferToBuffer(outputArrayBuffer, 0, readOutputArrayBuffer, 0, input.length * 4);
commandEncoder.copyBufferToBuffer(outputSumArrayBuffer, 0, readSumArrayBuffer, 0, powerOf2 * 4);
this._device.queue.submit([commandEncoder.finish()]); // submit() returns immediately; it is not a promise
await readOutputArrayBuffer.mapAsync(GPUMapMode.READ, 0, input.length * 4);
// copy the mapped data before unmapping so the returned array stays valid
const outputArray = new Float32Array(readOutputArrayBuffer.getMappedRange().slice(0));
readOutputArrayBuffer.unmap();
return outputArray;
}
This code takes the input array and passes it to the first shader, which writes the per-chunk prefix sums to the output buffer and the chunk totals to the sums buffer.
Then our second shader takes the sumArrayBuffer and the sumSizeBuffer value and populates the outputSumArrayBuffer with the scanned chunk offsets.
Finally, we use the outputSumArrayBuffer to add each chunk's offset to its values in the output buffer.
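Putting it together, a hypothetical caller might look like the following sketch (runPrefixSumDemo and the CPU verification are illustrative only; the input length is kept a multiple of 512 since the shaders assume full 512-element chunks):
// Hypothetical usage of the compute context, verified against a CPU exclusive scan.
async function runPrefixSumDemo(): Promise<void> {
  const created = await WebGPUComputeContext.create();
  if (!("instance" in created) || !created.instance) {
    console.error("Failed to create WebGPU compute context", created);
    return;
  }
  const compute = created.instance;
  const input = new Float32Array(1024).map(() => Math.floor(Math.random() * 10));
  const gpuResult = await compute.prefix_sum(input);
  // The GPU result is an exclusive scan: gpuResult[i] == sum of input[0..i-1].
  let running = 0;
  for (let i = 0; i < input.length; i++) {
    if (Math.abs(gpuResult[i] - running) > 1e-3) {
      console.error(`Mismatch at index ${i}: got ${gpuResult[i]}, expected ${running}`);
      return;
    }
    running += input[i];
  }
  console.log("GPU prefix sum matches the CPU reference");
}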
Conclusion
In this article, we did our first experimentation with the compute shader and used it to turn an O(n) sequential algorithm into O(log n) parallel passes. The sequential prefix sum algorithm was much simpler, but these kinds of GPU computations shine when the work is something we must do on every render pass.