Skip to content

Commit

Permalink
Amortize sha2 compression loop (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nashtare committed May 20, 2024
1 parent 827925d commit bde253b
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ compression_loop:
%mload_kernel_code_u32
// stack: K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest
DUP10
DUP10
DUP10
DUP10
// stack: e[i], f[i], g[i], h[i], K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest
DUP8
DUP11
DUP11
DUP11
// stack: e[i], f[i], g[i], e[i], h[i], K[i], W[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest
%sha2_temp_word1
// stack: T1[i], a[i], b[i], c[i], d[i], e[i], f[i], g[i], h[i], num_blocks, scratch_space_addr, message_schedule_addr, i, a[0]..h[0], retdest
DUP4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ gen_message_schedule_from_block:
// stack: block_addr, output_addr, retdest
DUP1
// stack: block_addr, block_addr, output_addr, retdest
%add_const(32)
// stack: block_addr + 32, block_addr, output_addr, retdest
SWAP1
// stack: block_addr, block_addr + 32, output_addr, retdest
%mload_u256
// stack: block[0], block_addr + 32, output_addr, retdest
// stack: block[0], block_addr, output_addr, retdest
SWAP1
// stack: block_addr, block[0], output_addr, retdest
%add_const(32)
// stack: block_addr + 32, block[0], output_addr, retdest
%mload_u256
// stack: block[1], block[0], output_addr, retdest
Expand All @@ -45,27 +43,22 @@ gen_message_schedule_from_block_0_loop:
// stack: output_addr, block[0] % (1 << 32), block[0] >> 32, output_addr, counter, block[1], retdest
%mstore_u32
// stack: block[0] >> 32, output_addr, counter, block[1], retdest
SWAP1
// stack: output_addr, block[0] >> 32, counter, block[1], retdest
%sub_const(4)
// stack: output_addr - 4, block[0] >> 32, counter, block[1], retdest
SWAP1
// stack: block[0] >> 32, output_addr - 4, counter, block[1], retdest
%stack (block0_shifted, output_addr, counter) -> (output_addr, 4, 1, counter, block0_shifted)
SUB
// stack: output_addr - 4, 1, counter, block[0] >> 32, block[1], retdest
SWAP2
// stack: counter, output_addr - 4, block[0] >> 32, block[1], retdest
%decrement
SUB
// stack: counter - 1, output_addr - 4, block[0] >> 32, block[1], retdest
DUP1
%jumpi(gen_message_schedule_from_block_0_loop)
gen_message_schedule_from_block_0_end:
// stack: old counter=0, output_addr, block[0], block[1], retdest
POP
// stack: output_addr, block[0], block[1], retdest
%stack (out, b0, b1) -> (out, 8, b1, b0)
// stack: output_addr, counter=8, block[1], block[0], retdest
%add_const(64)
// stack: output_addr + 64, counter, block[1], block[0], retdest
SWAP1
// stack: counter, output_addr + 64, block[1], block[0], retdest
// stack: output_addr + 64, block[0], block[1], retdest
%stack (out, b0, b1) -> (8, out, b1, b0)
// stack: counter=8, output_addr + 64, block[1], block[0], retdest
gen_message_schedule_from_block_1_loop:
// Split the second half (256 bits) of the block into the next eight (32-bit) chunks of the message sdchedule.
// stack: counter, output_addr, block[1], block[0], retdest
Expand All @@ -83,29 +76,22 @@ gen_message_schedule_from_block_1_loop:
// stack: output_addr, block[1] % (1 << 32), block[1] >> 32, output_addr, counter, block[0], retdest
%mstore_u32
// stack: block[1] >> 32, output_addr, counter, block[0], retdest
SWAP1
// stack: output_addr, block[1] >> 32, counter, block[0], retdest
%sub_const(4)
// stack: output_addr - 4, block[1] >> 32, counter, block[0], retdest
SWAP1
// stack: block[1] >> 32, output_addr - 4, counter, block[0], retdest
%stack (block1_shifted, output_addr, counter) -> (output_addr, 4, 1, counter, block1_shifted)
SUB
// stack: output_addr - 4, 1, counter, block[1] >> 32, block[0], retdest
SWAP2
// stack: counter, output_addr - 4, block[1] >> 32, block[0], retdest
%decrement
SUB
// stack: counter - 1, output_addr - 4, block[1] >> 32, block[0], retdest
DUP1
%jumpi(gen_message_schedule_from_block_1_loop)
gen_message_schedule_from_block_1_end:
// stack: old counter=0, output_addr, block[1], block[0], retdest
POP
// stack: output_addr, block[0], block[1], retdest
PUSH 48
// stack: counter=48, output_addr, block[0], block[1], retdest
SWAP1
// stack: output_addr, counter, block[0], block[1], retdest
%add_const(36)
// stack: output_addr + 36, counter, block[0], block[1], retdest
SWAP1
// stack: counter, output_addr + 36, block[0], block[1], retdest
// stack: output_addr + 36, block[0], block[1], retdest
PUSH 48
// stack: counter=48, output_addr + 36, block[0], block[1], retdest
gen_message_schedule_remaining_loop:
// Generate the next 48 chunks of the message schedule, one at a time, from prior chunks.
// stack: counter, output_addr, block[0], block[1], retdest
Expand Down Expand Up @@ -153,9 +139,10 @@ gen_message_schedule_remaining_loop:
// stack: output_addr, x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], sigma_1(x[output_addr - 2*4]), counter, block[0], block[1], retdest
SWAP4
// stack: sigma_1(x[output_addr - 2*4]), x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
%add_u32
%add_u32
%add_u32
ADD
ADD
ADD
%as_u32
// stack: sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
DUP2
// stack: output_addr, sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
Expand All @@ -174,20 +161,22 @@ gen_message_schedule_remaining_end:
%pop4
JUMP

// Precodition: memory, starting at 0, contains num_blocks, block0[0], ..., block0[63], block1[0], ..., blocklast[63]
// Precondition: memory, starting at 0, contains num_blocks, block0[0], ..., block0[63], block1[0], ..., blocklast[63]
// stack contains output_addr
// Postcondition: starting at output_addr, set of 256 bytes per block
// each contains the 64 32-bit chunks of the message schedule for that block (in four-byte increments)
global sha2_gen_all_message_schedules:
// stack: output_addr, retdest
DUP1
// stack: output_addr, output_addr, retdest
%mload_current_general_no_offset
// stack: num_blocks, output_addr, output_addr, retdest
PUSH 1
// stack: cur_offset = 1, counter = num_blocks, output_addr, output_addr, retdest
%build_current_general_address
// stack: cur_addr, counter, output_addr, output_addr, retdest
%build_current_general_address_no_offset
DUP1
// stack: base_addr, base_addr, output_addr, output_addr, retdest
MLOAD_GENERAL
// stack: num_blocks, base_addr, output_addr, output_addr, retdest
SWAP1
%increment
// stack: cur_addr (offset = 1), counter = num_blocks, output_addr, output_addr, retdest
gen_all_message_schedules_loop:
// stack: cur_addr, counter, cur_output_addr, output_addr, retdest
PUSH gen_all_message_schedules_loop_end
Expand Down
120 changes: 49 additions & 71 deletions evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// 32-bit right rotation
%macro rotr(rot)
// stack: value
DUP1
// stack: value, value
PUSH $rot
// stack: rot, value
DUP2
DUP2
// stack: rot, value, rot, value
// stack: rot, value, value
SHR
// stack: value >> rot, rot, value
%stack (shifted, rot, value) -> (rot, value, shifted)
// stack: value >> rot, value
SWAP1
PUSH $rot
// stack: rot, value, value >> rot
PUSH 32
SUB
Expand All @@ -26,16 +26,14 @@
// stack: x, x
%rotr(7)
// stack: rotr(x, 7), x
SWAP1
// stack: x, rotr(x, 7)
DUP1
// stack: x, x, rotr(x, 7)
%rotr(18)
// stack: rotr(x, 18), x, rotr(x, 7)
SWAP1
// stack: x, rotr(x, 18), rotr(x, 7)
// stack: rotr(x, 7), rotr(x, 7), x
%rotr(11)
// stack: rotr(x, 18), rotr(x, 7), x
SWAP2
// stack: x, rotr(x, 7), rotr(x, 18)
%shr_const(3)
// stack: shr(x, 3), rotr(x, 18), rotr(x, 7)
// stack: shr(x, 3), rotr(x, 7), rotr(x, 18)
XOR
XOR
%endmacro
Expand All @@ -46,98 +44,78 @@
// stack: x, x
%rotr(17)
// stack: rotr(x, 17), x
SWAP1
// stack: x, rotr(x, 17)
DUP1
// stack: x, x, rotr(x, 17)
%rotr(19)
// stack: rotr(x, 19), x, rotr(x, 17)
SWAP1
// stack: x, rotr(x, 19), rotr(x, 17)
// stack: rotr(x, 17), rotr(x, 17), x
%rotr(2)
// stack: rotr(x, 19), rotr(x, 17), x
SWAP2
// stack: x, rotr(x, 17), rotr(x, 19)
PUSH 10
SHR
// stack: shr(x, 10), rotr(x, 19), rotr(x, 17)
// stack: shr(x, 10), rotr(x, 17), rotr(x, 19)
XOR
XOR
%endmacro

%macro sha2_bigsigma_0
// stack: x
DUP1
// stack: x, x
%rotr(2)
// stack: rotr(x, 2), x
SWAP1
// stack: x, rotr(x, 2)
// stack: rotr(x, 2)
DUP1
// stack: x, x, rotr(x, 2)
%rotr(13)
// stack: rotr(x, 13), x, rotr(x, 2)
SWAP1
// stack: x, rotr(x, 13), rotr(x, 2)
%rotr(22)
// stack: rotr(x, 2), rotr(x, 2)
%rotr(11)
// stack: rotr(x, 13), rotr(x, 2)
DUP1
// stack: rotr(x, 13), rotr(x, 13), rotr(x, 2)
%rotr(9)
// stack: rotr(x, 22), rotr(x, 13), rotr(x, 2)
XOR
XOR
%endmacro

%macro sha2_bigsigma_1
// stack: x
DUP1
// stack: x, x
%rotr(6)
// stack: rotr(x, 6), x
SWAP1
// stack: x, rotr(x, 6)
// stack: rotr(x, 6)
DUP1
// stack: x, x, rotr(x, 6)
%rotr(11)
// stack: rotr(x, 11), x, rotr(x, 6)
SWAP1
// stack: x, rotr(x, 11), rotr(x, 6)
%rotr(25)
// stack: rotr(x, 6), rotr(x, 6)
%rotr(5)
// stack: rotr(x, 11), rotr(x, 6)
DUP1
// stack: rotr(x, 11), rotr(x, 11), rotr(x, 6)
%rotr(14)
// stack: rotr(x, 25), rotr(x, 11), rotr(x, 6)
XOR
XOR
%endmacro

%macro sha2_choice
// stack: x, y, z
DUP1
// stack: x, x, y, z
NOT
// stack: not x, x, y, z
SWAP1
// stack: x, not x, y, z
SWAP3
// stack: z, not x, y, x
AND
// stack: (not x) and z, y, x
SWAP2
// stack: x, y, (not x) and z
// stack: y, x, z
DUP3
// stack: z, y, x, z
XOR
// stack: z xor y, x, z
AND
// stack: x and y, (not x) and z
OR
// stack: (z xor y) and x, z
XOR
// stack: ((z xor y) and x) xor z == (x and y) xor (not x and z)
%endmacro

%macro sha2_majority
// stack: x, y, z
DUP1
// stack: x, x, y, z
DUP3
// stack: y, x, x, y, z
DUP5
// stack: z, y, x, x, y, z
AND
// stack: z and y, x, x, y, z
SWAP4
// stack: z, x, x, y, z and y
DUP2
DUP2
AND
// stack: z and x, x, y, z and y
// stack: x and y, x, y, z
SWAP2
// stack: y, x, z and x, z and y
AND
// stack: y and x, z and x, z and y
// stack: y, x, x and y, z
OR
// stack: y or x, x and y, z
%stack(y_or_x, x_and_y, z) -> (z, y_or_x, x_and_y)
AND
// stack: z and (y or x), x and y
OR
// stack: (z and (y or x) or (x and y) == (x and y) or (x and z) or (y and z)
%endmacro
24 changes: 11 additions & 13 deletions evm_arithmetization/src/cpu/kernel/asm/hash/sha2/temp_words.asm
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
// "T_1" in the SHA-256 spec
%macro sha2_temp_word1
// stack: e, f, g, h, K[i], W[i]
DUP1
// stack: e, e, f, g, h, K[i], W[i]
%sha2_bigsigma_1
// stack: Sigma_1(e), e, f, g, h, K[i], W[i]
%stack (sig, e, f, g) -> (e, f, g, sig)
// stack: e, f, g, Sigma_1(e), h, K[i], W[i]
// stack: e, f, g, e, h, K[i], W[i]
%sha2_choice
// stack: Ch(e, f, g), Sigma_1(e), h, K[i], W[i]
%add_u32
%add_u32
%add_u32
%add_u32
// stack: Ch(e, f, g), e, h, K[i], W[i]
SWAP1
// stack: e, Ch(e, f, g), h, K[i], W[i]
%sha2_bigsigma_1
// stack: Sigma_1(e), Ch(e, f, g), h, K[i], W[i]
ADD
ADD
ADD
ADD
// stack: Ch(e, f, g) + Sigma_1(e) + h + K[i] + W[i]
%endmacro

Expand All @@ -27,6 +25,6 @@
// stack: c, a, b, Sigma_0(a)
%sha2_majority
// stack: Maj(c, a, b), Sigma_0(a)
%add_u32
ADD
// stack: Maj(c, a, b) + Sigma_0(a)
%endmacro
Loading

0 comments on commit bde253b

Please sign in to comment.