Skip to content

Commit

Permalink
[FORK][FIX] IP weights compression: scalar scale
Browse files Browse the repository at this point in the history
[FORK][FEATURE] InnerProduct primitive: squashed weight decompression
  • Loading branch information
dmitry-gorokhov committed Jan 14, 2025
1 parent 1efdaaa commit c7ecd8f
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ struct jit_brgemm_kernel_t : public jit_generator {
used_vregs = 5;
else if (brg.is_f16_b_non_amx_vnni())
used_vregs = 2;

if (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2) {
used_vregs += 5;
}
Expand Down Expand Up @@ -2431,7 +2431,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel_dyn_quant(int bd_block2,
for (int bd = bd_b; bd < bd_e; bd++) {
uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]);
for (int ld = 0; ld < ld_block2; ld++) {
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
if (brg.wei_decomp_scales_stride == 0) {
uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]);
} else {
uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]);
}
}
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld);
Expand Down Expand Up @@ -2901,7 +2905,11 @@ void jit_brgemm_kernel_t<Wmm>::gemm_microkernel(int bd_block2, bool is_bdb_tail,
for (int ld = 0; ld < ld_block2; ld++) {
auto vmm_accm_tmp = accm_tmp(ld_block2, 0, ld);
auto vmm_accm = accm(ld_block2, 0, ld);
load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]);
if (brg.wei_decomp_scales_stride == 0) {
load_scales(bcst(), ptr[reg_local_wei_scales]);
} else {
load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]);
}
uni_vfmadd231ps(vmm_accm, vmm_accm_tmp, bcst());
}
}
Expand Down Expand Up @@ -3025,8 +3033,8 @@ void jit_brgemm_kernel_t<Wmm>::ldb_loop(int bd_block2, bool is_bdb_tail,
mov(reg_rdb_loop, brg.rdb);
L_aligned(rdb_loop_label, 64);
{
if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
brg.wei_decomp_zero_points_stride != 0)) {
if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 ||
brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) {
auto reg_local_ic = reg_aux_D;
auto reg_local_wei_params = reg_bdb_loop;
auto reg_local_ic_group = reg_ldb_loop;
Expand Down

0 comments on commit c7ecd8f

Please sign in to comment.