Skip to content

Commit

Permalink
muriscv_nn_vec_mat_mult_t_s8: fallback to scalar mode if USE_VECT and…
Browse files Browse the repository at this point in the history
… address_offset != 1 (see #81)
  • Loading branch information
PhilippvK committed Oct 31, 2024
1 parent 975d0fa commit 3702842
Showing 1 changed file with 127 additions and 0 deletions.
127 changes: 127 additions & 0 deletions Source/NNSupportFunctions/muriscv_nn_vec_mat_mult_t_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ muriscv_nn_status muriscv_nn_vec_mat_mult_t_s8(const q7_t *lhs,
const int32_t rhs_offset) // Currently Unused
{
#if defined(USE_VEXT)
if (address_offset == 1) { // TODO: remove this after #81 is fixed

/* At some point in time there might be implementations with the Zvediv extension.
* It might provide dot-product functions which would simplify this code here and
Expand Down Expand Up @@ -249,8 +250,134 @@ muriscv_nn_status muriscv_nn_vec_mat_mult_t_s8(const q7_t *lhs,
*dst = (q7_t)res_scalar;
dst += address_offset;
rhs += rhs_cols;

}
} else { // TODO: remove this after #81 is fixed
/* Uses 5x loop unrolling in order to expose more ILP */
const int32_t row_loop_cnt = rhs_rows / 5;
for (int i_row_loop_cnt = 0; i_row_loop_cnt < row_loop_cnt; i_row_loop_cnt++)
{

const q7_t *lhs_ptr = lhs;
const q7_t *rhs_ptr_0 = &rhs[0];
const q7_t *rhs_ptr_1 = &rhs[rhs_cols];
const q7_t *rhs_ptr_2 = &rhs[rhs_cols * 2];
const q7_t *rhs_ptr_3 = &rhs[rhs_cols * 3];
const q7_t *rhs_ptr_4 = &rhs[rhs_cols * 4];

q31_t res00 = 0;
q31_t res01 = 0;
q31_t res02 = 0;
q31_t res03 = 0;
q31_t res04 = 0;
if (bias)
{
res00 = *bias++;
res01 = *bias++;
res02 = *bias++;
res03 = *bias++;
res04 = *bias++;
}
for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
{
const q31_t rhs_value0 = (int8_t)*rhs_ptr_0;
const q31_t rhs_value1 = (int8_t)*rhs_ptr_1;
const q31_t rhs_value2 = (int8_t)*rhs_ptr_2;
const q31_t rhs_value3 = (int8_t)*rhs_ptr_3;
const q31_t rhs_value4 = (int8_t)*rhs_ptr_4;
const q31_t lhs_value = (int8_t)*lhs_ptr + lhs_offset;

res00 += lhs_value * rhs_value0;
res01 += lhs_value * rhs_value1;
res02 += lhs_value * rhs_value2;
res03 += lhs_value * rhs_value3;
res04 += lhs_value * rhs_value4;

++rhs_ptr_0;
++rhs_ptr_1;
++rhs_ptr_2;
++rhs_ptr_3;
++rhs_ptr_4;
++lhs_ptr;
}
// Quantize down
res00 = muriscv_nn_requantize(res00, dst_multiplier, dst_shift);
res01 = muriscv_nn_requantize(res01, dst_multiplier, dst_shift);
res02 = muriscv_nn_requantize(res02, dst_multiplier, dst_shift);
res03 = muriscv_nn_requantize(res03, dst_multiplier, dst_shift);
res04 = muriscv_nn_requantize(res04, dst_multiplier, dst_shift);

// Add offset
res00 += dst_offset;
res01 += dst_offset;
res02 += dst_offset;
res03 += dst_offset;
res04 += dst_offset;

// Clamp the result
res00 = MAX(res00, activation_min);
res00 = MIN(res00, activation_max);
res01 = MAX(res01, activation_min);
res01 = MIN(res01, activation_max);
res02 = MAX(res02, activation_min);
res02 = MIN(res02, activation_max);
res03 = MAX(res03, activation_min);
res03 = MIN(res03, activation_max);
res04 = MAX(res04, activation_min);
res04 = MIN(res04, activation_max);

*dst = (q7_t)res00;
*(dst + address_offset) = (q7_t)res01;
*(dst + 2 * address_offset) = (q7_t)res02;
*(dst + 3 * address_offset) = (q7_t)res03;
*(dst + 4 * address_offset) = (q7_t)res04;
dst += 5 * address_offset;

rhs += 5 * rhs_cols;

}

/* Handle the leftover part from 5x loop unrolling */
const int loop_cnt = rhs_rows % 5;
for (int i_loop_cnt = 0; i_loop_cnt < loop_cnt; i_loop_cnt++)
{

const q7_t *lhs_ptr = &lhs[0];
const q7_t *rhs_ptr = &rhs[0];

q31_t res00 = 0;
if (bias)
{
res00 = *bias++;
}

for (int32_t rhs_cols_idx = 0; rhs_cols_idx < rhs_cols; ++rhs_cols_idx)
{
q31_t rhs_value0 = (int8_t)rhs_ptr[0];
q31_t lhs_value = (int8_t)lhs_ptr[0] + lhs_offset;

res00 += lhs_value * rhs_value0;

++rhs_ptr;
++lhs_ptr;
}

// Quantize down
res00 = muriscv_nn_requantize(res00, dst_multiplier, dst_shift);

// Add offset
res00 += dst_offset;

// Clamp the result
res00 = MAX(res00, activation_min);
res00 = MIN(res00, activation_max);

*dst = (int8_t)res00;
dst += address_offset;
rhs += rhs_cols;

}
} // TODO: remove this after #81 is fixed

#elif defined(USE_PEXT) /* defined(USE_VEXT) */

Expand Down

0 comments on commit 3702842

Please sign in to comment.