From a60c3cdbbf4f35c1077046e4ff990ad9cfa96a10 Mon Sep 17 00:00:00 2001 From: Andreas Woess Date: Mon, 16 Dec 2024 21:43:57 +0100 Subject: [PATCH 1/3] Fix br_table_i32 branch profile counter offset. (cherry picked from commit c4492878c31c25973c19e1d16c61ccf3871b6da1) --- .../src/org/graalvm/wasm/nodes/WasmFunctionNode.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index 620c89b04b51..fe1a41970ffe 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -435,6 +435,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, stackPointer--; int index = popInt(frame, stackPointer); final int size = rawPeekU8(bytecode, offset); + final int counterOffset = offset + 1; if (index < 0 || index >= size) { // If unsigned index is larger or equal to the table size use the // default (last) index. @@ -443,7 +444,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, if (CompilerDirectives.inInterpreter()) { final int indexOffset = offset + 3 + index * 6; - updateBranchTableProfile(bytecode, offset + 1, indexOffset + 4); + updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4); final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; break; @@ -453,7 +454,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, // time constants, since the loop is unrolled. for (int i = 0; i < size; i++) { final int indexOffset = offset + 3 + i * 6; - if (profileBranchTable(bytecode, offset + 1, indexOffset + 4, i == index)) { + if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index)) { final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; continue loop; @@ -467,6 +468,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, stackPointer--; int index = popInt(frame, stackPointer); final int size = rawPeekI32(bytecode, offset); + final int counterOffset = offset + 4; if (index < 0 || index >= size) { // If unsigned index is larger or equal to the table size use the // default (last) index. @@ -475,7 +477,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, if (CompilerDirectives.inInterpreter()) { final int indexOffset = offset + 6 + index * 6; - updateBranchTableProfile(bytecode, offset + 4, indexOffset + 4); + updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4); final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; break; @@ -485,7 +487,7 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, // time constants, since the loop is unrolled. for (int i = 0; i < size; i++) { final int indexOffset = offset + 6 + i * 6; - if (profileBranchTable(bytecode, offset + 1, indexOffset + 4, i == index)) { + if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index)) { final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; continue loop; From 2aefeed451b6846d106431cab6e4703e5a74167a Mon Sep 17 00:00:00 2001 From: Andreas Woess Date: Mon, 16 Dec 2024 21:46:48 +0100 Subject: [PATCH 2/3] Simplify br_table dispatch (avoid conditional around index). (cherry picked from commit 5f2329b00f6ab8cd3d313748f1606359fa2aecdf) --- .../graalvm/wasm/nodes/WasmFunctionNode.java | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index fe1a41970ffe..be20687dc384 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -436,13 +436,14 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, int index = popInt(frame, stackPointer); final int size = rawPeekU8(bytecode, offset); final int counterOffset = offset + 1; - if (index < 0 || index >= size) { - // If unsigned index is larger or equal to the table size use the - // default (last) index. - index = size - 1; - } if (CompilerDirectives.inInterpreter()) { + if (index < 0 || index >= size) { + // If unsigned index is larger or equal to the table size use the + // default (last) index. + index = size - 1; + } + final int indexOffset = offset + 3 + index * 6; updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4); final int offsetDelta = rawPeekI32(bytecode, indexOffset); @@ -454,28 +455,28 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, // time constants, since the loop is unrolled. for (int i = 0; i < size; i++) { final int indexOffset = offset + 3 + i * 6; - if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index)) { + if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index || i == size - 1)) { final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; continue loop; } } + throw CompilerDirectives.shouldNotReachHere("br_table"); } - enterErrorBranch(); - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, this, "Should not reach here"); } case Bytecode.BR_TABLE_I32: { stackPointer--; int index = popInt(frame, stackPointer); final int size = rawPeekI32(bytecode, offset); final int counterOffset = offset + 4; - if (index < 0 || index >= size) { - // If unsigned index is larger or equal to the table size use the - // default (last) index. - index = size - 1; - } if (CompilerDirectives.inInterpreter()) { + if (index < 0 || index >= size) { + // If unsigned index is larger or equal to the table size use the + // default (last) index. + index = size - 1; + } + final int indexOffset = offset + 6 + index * 6; updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4); final int offsetDelta = rawPeekI32(bytecode, indexOffset); @@ -487,15 +488,14 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance, // time constants, since the loop is unrolled. for (int i = 0; i < size; i++) { final int indexOffset = offset + 6 + i * 6; - if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index)) { + if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index || i == size - 1)) { final int offsetDelta = rawPeekI32(bytecode, indexOffset); offset = indexOffset + offsetDelta; continue loop; } } + throw CompilerDirectives.shouldNotReachHere("br_table"); } - enterErrorBranch(); - throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, this, "Should not reach here"); } case Bytecode.CALL_U8: case Bytecode.CALL_I32: { From 5fd921c97e0320927a0cc8e79e396cf036a57d07 Mon Sep 17 00:00:00 2001 From: Andreas Woess Date: Mon, 16 Dec 2024 21:53:48 +0100 Subject: [PATCH 3/3] Fix br_table branch profile counter overflow handling. (cherry picked from commit 49ad451fd0faed030531dfb4d10773814352dfcb) --- .../graalvm/wasm/nodes/WasmFunctionNode.java | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java index be20687dc384..96bdfaf4beab 100644 --- a/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java +++ b/wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/nodes/WasmFunctionNode.java @@ -4681,11 +4681,26 @@ private static boolean profileCondition(byte[] data, final int profileOffset, bo } private static void updateBranchTableProfile(byte[] data, final int counterOffset, final int profileOffset) { - assert CompilerDirectives.inInterpreter(); + CompilerAsserts.neverPartOfCompilation(); int counter = rawPeekU16(data, counterOffset); + int profile = rawPeekU16(data, profileOffset); + /* + * Even if the total hit counter has already reached the limit, we need to increment the + * branch profile counter from 0 to 1 iff it's still 0 to mark the branch as having been + * taken at least once, to prevent recurrent deoptimizations due to profileBranchTable + * assuming that a value of 0 means the branch has never been reached. + * + * Similarly, we need to make sure we never increase any branch counter to the max value, + * otherwise we can get into a situation where both the branch and the total counter values + * are at the max value that we cannot recover from since we never decrease counter values; + * profileBranchTable would then deoptimize every time that branch is not taken (see below). + */ + assert profile != MAX_TABLE_PROFILE_VALUE; if (counter < MAX_TABLE_PROFILE_VALUE) { BinaryStreamParser.writeU16(data, counterOffset, counter + 1); - BinaryStreamParser.writeU16(data, profileOffset, rawPeekU16(data, profileOffset) + 1); + } + if ((counter < MAX_TABLE_PROFILE_VALUE || profile == 0) && (profile < MAX_TABLE_PROFILE_VALUE - 1)) { + BinaryStreamParser.writeU16(data, profileOffset, profile + 1); } }