Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Amortize sha2 compression loop #231

Merged
merged 10 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ gen_message_schedule_remaining_loop:
// stack: output_addr, x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], sigma_1(x[output_addr - 2*4]), counter, block[0], block[1], retdest
SWAP4
// stack: sigma_1(x[output_addr - 2*4]), x[output_addr - 16*4], sigma_0(x[output_addr - 15*4]), x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
%add_u32
%add_u32
%add_u32
ADD
ADD
ADD
%as_u32
// stack: sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
DUP2
// stack: output_addr, sigma_1(x[output_addr - 2*4]) + x[output_addr - 16*4] + sigma_0(x[output_addr - 15*4]) + x[output_addr - 7*4], output_addr, counter, block[0], block[1], retdest
Expand All @@ -182,12 +183,14 @@ global sha2_gen_all_message_schedules:
// stack: output_addr, retdest
DUP1
// stack: output_addr, output_addr, retdest
%mload_current_general_no_offset
// stack: num_blocks, output_addr, output_addr, retdest
PUSH 1
// stack: cur_offset = 1, counter = num_blocks, output_addr, output_addr, retdest
%build_current_general_address
// stack: cur_addr, counter, output_addr, output_addr, retdest
%build_current_general_address_no_offset
DUP1
// stack: base_addr, base_addr, output_addr, output_addr, retdest
MLOAD_GENERAL
// stack: num_blocks, base_addr, output_addr, output_addr, retdest
SWAP1
%increment
// stack: cur_addr (offset = 1), counter = num_blocks, output_addr, output_addr, retdest
gen_all_message_schedules_loop:
// stack: cur_addr, counter, cur_output_addr, output_addr, retdest
PUSH gen_all_message_schedules_loop_end
Expand Down
120 changes: 49 additions & 71 deletions evm_arithmetization/src/cpu/kernel/asm/hash/sha2/ops.asm
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// 32-bit right rotation
%macro rotr(rot)
// stack: value
DUP1
// stack: value, value
PUSH $rot
// stack: rot, value
DUP2
DUP2
// stack: rot, value, rot, value
// stack: rot, value, value
SHR
// stack: value >> rot, rot, value
%stack (shifted, rot, value) -> (rot, value, shifted)
// stack: value >> rot, value
SWAP1
PUSH $rot
Copy link
Collaborator Author

@Nashtare Nashtare May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm pushing again instead of calling DUP because $rot always fits in at most 4 bytes anyway, and the savings in the number of CPU cycles are worth the extra overhead on BytePackingStark.

// stack: rot, value, value >> rot
PUSH 32
SUB
Expand All @@ -26,16 +26,14 @@
// stack: x, x
%rotr(7)
// stack: rotr(x, 7), x
SWAP1
// stack: x, rotr(x, 7)
DUP1
// stack: x, x, rotr(x, 7)
%rotr(18)
// stack: rotr(x, 18), x, rotr(x, 7)
SWAP1
// stack: x, rotr(x, 18), rotr(x, 7)
// stack: rotr(x, 7), rotr(x, 7), x
%rotr(11)
// stack: rotr(x, 18), rotr(x, 7), x
SWAP2
// stack: x, rotr(x, 7), rotr(x, 18)
%shr_const(3)
// stack: shr(x, 3), rotr(x, 18), rotr(x, 7)
// stack: shr(x, 3), rotr(x, 7), rotr(x, 18)
XOR
XOR
%endmacro
Expand All @@ -46,98 +44,78 @@
// stack: x, x
%rotr(17)
// stack: rotr(x, 17), x
SWAP1
// stack: x, rotr(x, 17)
DUP1
// stack: x, x, rotr(x, 17)
%rotr(19)
// stack: rotr(x, 19), x, rotr(x, 17)
SWAP1
// stack: x, rotr(x, 19), rotr(x, 17)
// stack: rotr(x, 17), rotr(x, 17), x
%rotr(2)
// stack: rotr(x, 19), rotr(x, 17), x
SWAP2
// stack: x, rotr(x, 17), rotr(x, 19)
PUSH 10
SHR
// stack: shr(x, 10), rotr(x, 19), rotr(x, 17)
// stack: shr(x, 10), rotr(x, 17), rotr(x, 19)
XOR
XOR
%endmacro

%macro sha2_bigsigma_0
// stack: x
DUP1
// stack: x, x
%rotr(2)
// stack: rotr(x, 2), x
SWAP1
// stack: x, rotr(x, 2)
// stack: rotr(x, 2)
DUP1
// stack: x, x, rotr(x, 2)
%rotr(13)
// stack: rotr(x, 13), x, rotr(x, 2)
SWAP1
// stack: x, rotr(x, 13), rotr(x, 2)
%rotr(22)
// stack: rotr(x, 2), rotr(x, 2)
%rotr(11)
// stack: rotr(x, 13), rotr(x, 2)
DUP1
// stack: rotr(x, 13), rotr(x, 13), rotr(x, 2)
%rotr(9)
// stack: rotr(x, 22), rotr(x, 13), rotr(x, 2)
XOR
XOR
%endmacro

%macro sha2_bigsigma_1
// stack: x
DUP1
// stack: x, x
%rotr(6)
// stack: rotr(x, 6), x
SWAP1
// stack: x, rotr(x, 6)
// stack: rotr(x, 6)
DUP1
// stack: x, x, rotr(x, 6)
%rotr(11)
// stack: rotr(x, 11), x, rotr(x, 6)
SWAP1
// stack: x, rotr(x, 11), rotr(x, 6)
%rotr(25)
// stack: rotr(x, 6), rotr(x, 6)
%rotr(5)
// stack: rotr(x, 11), rotr(x, 6)
DUP1
// stack: rotr(x, 11), rotr(x, 11), rotr(x, 6)
%rotr(14)
// stack: rotr(x, 25), rotr(x, 11), rotr(x, 6)
XOR
XOR
%endmacro

%macro sha2_choice
// stack: x, y, z
DUP1
// stack: x, x, y, z
NOT
// stack: not x, x, y, z
SWAP1
// stack: x, not x, y, z
SWAP3
// stack: z, not x, y, x
AND
// stack: (not x) and z, y, x
SWAP2
// stack: x, y, (not x) and z
// stack: y, x, z
DUP3
// stack: z, y, x, z
XOR
// stack: z xor y, x, z
AND
// stack: x and y, (not x) and z
OR
// stack: (z xor y) and x, z
XOR
// stack: ((z xor y) and x) xor z == (x and y) xor (not x and z)
%endmacro

%macro sha2_majority
// stack: x, y, z
DUP1
// stack: x, x, y, z
DUP3
// stack: y, x, x, y, z
DUP5
// stack: z, y, x, x, y, z
AND
// stack: z and y, x, x, y, z
SWAP4
// stack: z, x, x, y, z and y
DUP2
DUP2
AND
// stack: z and x, x, y, z and y
// stack: x and y, x, y, z
SWAP2
// stack: y, x, z and x, z and y
AND
// stack: y and x, z and x, z and y
// stack: y, x, x and y, z
OR
// stack: y or x, x and y, z
%stack(y_or_x, x_and_y, z) -> (z, y_or_x, x_and_y)
AND
// stack: (z and (y or x)), x and y
Nashtare marked this conversation as resolved.
Show resolved Hide resolved
OR
// stack: (z and (y or x)) or (x and y) == (x and y) or (x and z) or (y and z)
%endmacro
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
// stack: e, f, g, Sigma_1(e), h, K[i], W[i]
%sha2_choice
// stack: Ch(e, f, g), Sigma_1(e), h, K[i], W[i]
%add_u32
%add_u32
%add_u32
%add_u32
ADD
ADD
ADD
ADD
%as_u32
// stack: Ch(e, f, g) + Sigma_1(e) + h + K[i] + W[i]
%endmacro

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
%build_current_general_address
SWAP1
// stack: length, last_addr
DUP1
// stack: length, length, last_addr
DUP2
DUP2
// stack: length, last_addr, length, last_addr
%and_const(0xff)
// stack: length % (1 << 8), length, last_addr
DUP3
// stack: last_addr, length % (1 << 8), length, last_addr
%swap_mstore
// stack: length % (1 << 8), last_addr, length, last_addr
MSTORE_GENERAL

%rep 7
// For i = 0 to 6
Expand All @@ -20,14 +19,13 @@
// stack: length >> (8 * i), last_addr - i - 2
%shr_const(8)
// stack: length >> (8 * (i + 1)), last_addr - i - 2
PUSH 256
DUP2
// stack: length >> (8 * (i + 1)), 256, length >> (8 * (i + 1)), last_addr - i - 2
PUSH 256
DUP3
// stack: length >> (8 * (i + 1)), 256, last_addr - i - 2, length >> (8 * (i + 1)), last_addr - i - 2
MOD
// stack: (length >> (8 * (i + 1))) % (1 << 8), length >> (8 * (i + 1)), last_addr - i - 2
DUP3
// stack: last_addr - i - 2, (length >> (8 * (i + 1))) % (1 << 8), length >> (8 * (i + 1)), last_addr - i - 2
%swap_mstore
MSTORE_GENERAL
%endrep

%pop2
Expand Down