Skip to content

Commit

Permalink
More changes to matmul blog
Browse files Browse the repository at this point in the history
* Add header for dispatch
* Move to 16,16 workgroups everywhere 2D
* Make it clear dispatch count it running on CPU not dispatches are for CPU
  • Loading branch information
LegNeato committed Nov 20, 2024
1 parent 9fd6eeb commit f7ff7d4
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ impl Gpu for Workgroup2d {

impl GridComputation for Workgroup2d {
fn workgroup(&self) -> UVec3 {
UVec3::new(8, 8, 1)
UVec3::new(16, 16, 1)
}

fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use settings::Dimensions;
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(8, 8)))]
#[spirv(compute(threads(16, 16)))]
pub fn matmul(
#[spirv(global_invocation_id)] global_id: UVec3,
#[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
Expand Down
7 changes: 5 additions & 2 deletions blog/2024-11-21-optimizing-matrix-mul/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ examples.

:::

Each workgroup, since it’s only one thread, processes one `result[i, j]`.
#### Dispatching workgroups

Each workgroup, since it’s only one thread (`#[spirv(compute(threads(1)))]`), processes
one `result[i, j]`.

To calculate the full matrix, we need to launch as many entries as there are in the
matrix. Here we specify that (`Uvec3::new(m * n, 1, 1`) on the CPU:
Expand Down Expand Up @@ -241,7 +244,7 @@ Although we don't change much about our code, if we distribute our work in 2 dim
we're able to bypass these limits and launch more workgroups that are larger. This
allows us to calculate a 4096x4096 matmul.

We update our `compute(threads(256)))` to `compute(threads((8, 8)))`, and make the small
We update our `compute(threads(256)))` to `compute(threads((16, 16)))`, and make the small
change to `row` and `col` from Zach's post to increase speed:

import { RustWorkgroup2d } from './snippets/workgroup_2d.tsx';
Expand Down
4 changes: 2 additions & 2 deletions blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ export const RustNaiveWorkgroupCount: React.FC = () => (
language="rust"
className="text-xs"
lines="26-34"
title="Calculating how many workgroup dispatches are needed on the CPU"
title="Calculating on the CPU how many workgroup dispatches are needed"
>
{RustWorkgroupCount}
</Snippet>
Expand All @@ -71,7 +71,7 @@ export const RustNaiveDispatch: React.FC = () => (
className="text-xs"
lines="145,147"
strip_leading_spaces
title="Using wgpu on the CPU to dispatch to the GPU"
title="Using wgpu on the CPU to dispatch workgroups to the GPU"
>
{RustWgpuBackend}
</Snippet>
Expand Down

0 comments on commit f7ff7d4

Please sign in to comment.