diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs b/blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs index 57850d0..3aef2b8 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/bin/blog/src/bin.rs @@ -30,7 +30,7 @@ fn main() { run_tests(matmul::naive::wgpu(), &sizes); run_tests(matmul::workgroup_256::wgpu(), &sizes); run_tests(matmul::workgroup_2d::wgpu(), &sizes); - //run_tests(matmul::tiling_1d::wgpu(), &sizes); + run_tests(matmul::tiling_1d::wgpu(), &sizes); run_tests(matmul::tiling_2d_simd::wgpu(), &sizes); run_tests(matmul::isomorphic::wgpu(), &sizes); diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs index 80a2e23..ad0d66d 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/cpu/matmul/src/backends/cpu.rs @@ -171,6 +171,37 @@ mod tests { assert_eq!(result, expected); } + #[test] + fn test_single_threaded_matmul_4x4() { + let m = 4; + let k = 4; + let n = 4; + + // Define matrix `a` (4x4) in row-major order + let a = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + ]; + + // Define matrix `b` (4x4) in row-major order + let b = vec![ + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, + 31.0, 32.0, + ]; + + // Expected result (4x4) after multiplying `a` and `b` + let expected = vec![ + 250.0, 260.0, 270.0, 280.0, 618.0, 644.0, 670.0, 696.0, 986.0, 1028.0, 1070.0, 1112.0, + 1354.0, 1412.0, 1470.0, 1528.0, + ]; + + let variant = crate::variants::Isomorphic; + let matrix_multiplier = futures::executor::block_on(SingleThreadedMatMul::new(variant)); + + let result = matrix_multiplier.multiply(&a, &b, m, k, n); + + assert_eq!(result, expected); + } + #[test] fn test_multithreaded_matmul_2x1x1() { let m = 2; diff --git a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs index 766a4fa..8c8b75c 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs +++ b/blog/2024-11-21-optimizing-matrix-mul/code/crates/gpu/tiling_1d/src/lib.rs @@ -27,14 +27,30 @@ pub fn matmul( for i in 0..dimensions.k as usize { let a_elem = a[row * dimensions.k as usize + i]; - sum00 += a_elem * b[i * dimensions.n as usize + col]; - sum01 += a_elem * b[i * dimensions.n as usize + col + 1]; - sum02 += a_elem * b[i * dimensions.n as usize + col + 2]; - sum03 += a_elem * b[i * dimensions.n as usize + col + 3]; + if col < dimensions.n as usize { + sum00 += a_elem * b[i * dimensions.n as usize + col]; + } + if col + 1 < dimensions.n as usize { + sum01 += a_elem * b[i * dimensions.n as usize + col + 1]; + } + if col + 2 < dimensions.n as usize { + sum02 += a_elem * b[i * dimensions.n as usize + col + 2]; + } + if col + 3 < dimensions.n as usize { + sum03 += a_elem * b[i * dimensions.n as usize + col + 3]; + } } - result[row * dimensions.n as usize + col] = sum00; - result[row * dimensions.n as usize + col + 1] = sum01; - result[row * dimensions.n as usize + col + 2] = sum02; - result[row * dimensions.n as usize + col + 3] = sum03; + if col < dimensions.n as usize { + result[row * dimensions.n as usize + col] = sum00; + } + if col + 1 < dimensions.n as usize { + result[row * dimensions.n as usize + col + 1] = sum01; + } + if col + 2 < dimensions.n as usize { + result[row * dimensions.n as usize + col + 2] = sum02; + } + if col + 3 < dimensions.n as usize { + result[row * dimensions.n as usize + col + 3] = sum03; + } } diff --git a/blog/2024-11-21-optimizing-matrix-mul/index.md b/blog/2024-11-21-optimizing-matrix-mul/index.md index 50ac366..1444a9f 100644 --- a/blog/2024-11-21-optimizing-matrix-mul/index.md +++ b/blog/2024-11-21-optimizing-matrix-mul/index.md @@ -275,7 +275,8 @@ import { RustTiling1d } from './snippets/tiling_1d.tsx'; The kernel looks roughly the same as before except we've unrolled the computation and -are calculating `TILE_SIZE` results per thread. +are calculating `TILE_SIZE` results per thread. We also need some error checking for +when our matrices don't fit nicely. We can take this a step further and calculate 2D results per thread! Instead of calculating 4 elements per single row, we can calculate 4 elements for 4 rows (e.g. a 2D