From 3da8a6a376a83593b663e782872a046ee2c59926 Mon Sep 17 00:00:00 2001 From: Christian Legnitto Date: Wed, 20 Nov 2024 19:41:22 -0400 Subject: [PATCH] More changes to matmul blog (#28) * Reword and move dispatch around * More changes to matmul blog * Add header for dispatch * Move to 16,16 workgroups everywhere 2D * Make it clear dispatch count it running on CPU not dispatches are for CPU --- .../code/crates/cpu/matmul/src/variants.rs | 2 +- .../code/crates/gpu/workgroup_2d/src/lib.rs | 2 +- .../2024-11-21-optimizing-matrix-mul/index.md | 90 +++++++++++++------ .../snippets/naive.tsx | 25 ++++++ .../snippets/workgroup_256.tsx | 13 --- 5 files changed, 90 insertions(+), 42 deletions(-) diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs index 357d73b..79058ab 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/variants.rs @@ -81,7 +81,7 @@ impl Gpu for Workgroup2d { impl GridComputation for Workgroup2d { fn workgroup(&self) -> UVec3 { - UVec3::new(8, 8, 1) + UVec3::new(16, 16, 1) } fn dispatch_count(&self, m: u32, n: u32) -> UVec3 { diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs index 54e5021..9faafdd 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/workgroup_2d/src/lib.rs @@ -4,7 +4,7 @@ use settings::Dimensions; use spirv_std::glam::UVec3; use spirv_std::spirv; -#[spirv(compute(threads(8, 8)))] +#[spirv(compute(threads(16, 16)))] pub fn matmul( #[spirv(global_invocation_id)] global_id: UVec3, #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions, diff --git a/blog/2024-11-21-optimizing-matrix-mul/index.md b/blog/2024-11-21-optimizing-matrix-mul/index.md index c1fb5a6..e1c7112 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/index.md +++ b/blog/2024-11-21-optimizing-matrix-mul/index.md @@ -88,6 +88,22 @@ platforms, including Windows, Linux, macOS, iOS[^1], Android, and the web[^2]. By using Rust GPU and `wgpu`, we have a clean, portable setup with everything written in Rust. +## GPU program basics + +The smallest unit of execution is a thread, which executes the GPU program. + +Workgroups are groups of threads: they are grouped together and run in parallel (they’re +called [thread blocks in +CUDA]()). They can access +the same shared memory. + +We can dispatch many of these workgroups at once. CUDA calls this a grid (which is made +of thread blocks). + +Workgroups and dispatching workgroups are defined in 3D. The size of a workgroup is +defined by `compute(threads((x, y, z)))` where the number of threads per workgroup is +x \* y \* z. + ## Writing the kernel ### Kernel 1: Naive kernel @@ -159,6 +175,35 @@ examples. ::: +#### Dispatching workgroups + +Each workgroup, since it’s only one thread (`#[spirv(compute(threads(1)))]`), processes +one `result[i, j]`. + +To calculate the full matrix, we need to launch as many entries as there are in the +matrix. Here we specify that (`Uvec3::new(m * n, 1, 1`) on the CPU: + +import { RustNaiveWorkgroupCount } from './snippets/naive.tsx'; + + + +The `dispatch_count()` function runs on the CPU and is used by the CPU-to-GPU API (in +our case `wgpu`) to configure and dispatch work to the GPU: + +import { RustNaiveDispatch } from './snippets/naive.tsx'; + + + +:::warning + +This code appears more complicated than it needs to be. I abstracted the CPU-side code +that talks to the GPU using generics and traits so I could easily slot in different +kernels and their settings while writing the blog post. + +You could just hardcode the value for simplicity. + +::: + ### Kernel 2: Moarrr threads! With the first kernel, we're only able to compute small square matrices due to limits on @@ -187,33 +232,19 @@ import { RustWorkgroup256WorkgroupCount } from './snippets/workgroup_256.tsx'; -The `dispatch_count()` function runs on the CPU and is used by the CPU-to-GPU API (in -our case `wgpu`) to configure and dispatch to the GPU: - -import { RustWorkgroup256WgpuDispatch } from './snippets/workgroup_256.tsx'; - - - -:::warning - -This code appears more complicated than it needs to be. I abstracted the CPU-side code -that talks to the GPU using generics and traits so I could easily slot in different -kernels and their settings while writing the blog post. - -You could just hardcode a value for simplicity. - -::: +With these two small changes we can handle larger matrices without hitting hardware +workgroup limits. ### Kernel 3: Calculating with 2D workgroups -However doing all the computation in "1 dimension" limits the matrix size we can +However, doing all the computation in "1 dimension" still limits the matrix size we can calculate. Although we don't change much about our code, if we distribute our work in 2 dimensions we're able to bypass these limits and launch more workgroups that are larger. This allows us to calculate a 4096x4096 matmul. -We update our `compute(threads(256)))` to `compute(threads((8, 8)))`, and make the small +We update our `compute(threads(256)))` to `compute(threads((16, 16)))`, and make the small change to `row` and `col` from Zach's post to increase speed: import { RustWorkgroup2d } from './snippets/workgroup_2d.tsx'; @@ -257,24 +288,29 @@ import { RustTiling2dSimd } from './snippets/tiling_2d_simd.tsx'; Each thread now calculates a 4x4 grid of the output matrix and we see a slight improvement over the last kernel. +To stay true to the spirit of Zach's original blog post, we'll wrap things up here and +leave the "fancier" experiments for another time. + ## Reflections on porting to Rust GPU Porting to Rust GPU went quickly, as the kernels Zach used were fairly simple. Most of the time was spent with concerns that were not specifically about writing GPU code. For example, deciding how much to abstract vs how much to make the code easy to follow, if everything should be available at runtime or if each kernel should be a compilation -target, etc. The code is not _great_ as it is still blog post code! +target, etc. [The +code](https://github.com/Rust-GPU/rust-gpu.github.io/tree/main/blog/2024-11-21-optimizing-matrix-mul/code) +is not _great_ as it is still blog post code! My background is not in GPU programming, but I do have Rust experience. I joined the Rust GPU project because I tried to use standard GPU languages and knew there must be a better way. Writing these GPU kernels felt like writing any other Rust code (other than -debugging, more on that later) which is a huge win to me. Not only the language itself, +debugging, more on that later) which is a huge win to me. Not just the language itself, but the entire development experience. ## Rust-specific party tricks Rust lets us write code for both the CPU and GPU in ways that are often impossible—or at -least less elegant—with other languages. I'm going to highlight some benefits of Rust I +least less elegant—with other languages. I'm going to highlight some benefits I experienced while working on this blog post. ### Shared code across GPU and CPU @@ -351,8 +387,9 @@ Testing the kernel in isolation is useful, but it does not reflect how the GPU e it with multiple invocations across workgroups and dispatches. To test the kernel end-to-end, I needed a test harness that simulated this behavior on the CPU. -Building the harness was straightforward. By enforcing the same invariants as the GPU I -could validate the kernel under the same conditions the GPU would run it: +Building the harness was straightforward due to the borrow checker. By enforcing the +same invariants as the GPU I could validate the kernel under the same conditions the GPU +would run it: import { RustCpuBackendHarness } from './snippets/party.tsx'; @@ -484,10 +521,9 @@ future. This kernel doesn't use conditional compilation, but it's a key feature of Rust that works with Rust GPU. With `#[cfg(...)]`, you can adapt kernels to different hardware or configurations without duplicating code. GPU languages like WGSL or GLSL offer -preprocessor directives, but these tools lack standardization across ecosystems. Rust -GPU leverages the existing Cargo ecosystem, so conditional compilation follows the same -standards all Rust developers already know. This makes adapting kernels for different -targets easier and more maintainable. +preprocessor directives, but these tools lack standardization across projects. Rust GPU +leverages the existing Cargo ecosystem, so conditional compilation follows the same +standards all Rust developers already know. ## Come join us! diff --git a/blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx b/blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx index 90ba96f..d637a49 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx +++ b/blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx @@ -2,6 +2,8 @@ import React from "react"; import CodeBlock from "@theme/CodeBlock"; import Snippet from "@site/src/components/Snippet"; import RustKernelSource from "!!raw-loader!../code/crates/gpu/naive/src/lib.rs"; +import RustWorkgroupCount from "!!raw-loader!../code/crates/cpu/matmul/src/variants.rs"; +import RustWgpuBackend from "!!raw-loader!../code/crates/cpu/matmul/src/backends/wgpu.rs"; export const WebGpuInputs: React.FC = () => ( @@ -52,6 +54,29 @@ export const RustNaiveInputs: React.FC = () => ( ); +export const RustNaiveWorkgroupCount: React.FC = () => ( + + {RustWorkgroupCount} + +); + +export const RustNaiveDispatch: React.FC = () => ( + + {RustWgpuBackend} + +); + export const RustNaiveWorkgroup: React.FC = () => ( {RustKernelSource} diff --git a/blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_256.tsx b/blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_256.tsx index f47f5cc..def6ba2 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_256.tsx +++ b/blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_256.tsx @@ -2,7 +2,6 @@ import React from "react"; import Snippet from "@site/src/components/Snippet"; import RustKernelSource from "!!raw-loader!../code/crates/gpu/workgroup_256/src/lib.rs"; import VariantsSource from "!!raw-loader!../code/crates/cpu/matmul/src/variants.rs"; -import WgpuBackendSource from "!!raw-loader!../code/crates/cpu/matmul/src/backends/wgpu.rs"; export const RustWorkgroup256Workgroup: React.FC = () => ( @@ -20,15 +19,3 @@ export const RustWorkgroup256WorkgroupCount: React.FC = () => ( {VariantsSource} ); - -export const RustWorkgroup256WgpuDispatch: React.FC = () => ( - - {WgpuBackendSource} - -);