Skip to content

Commit

Permalink
[GR-61341] Backport to 24.2: Fix br_table branch profiles.
Browse files Browse the repository at this point in the history
PullRequest: graal/19790
  • Loading branch information
woess authored and ansalond committed Jan 14, 2025
2 parents 8884e97 + 5fd921c commit 35c885b
Showing 1 changed file with 37 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -435,15 +435,17 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance,
stackPointer--;
int index = popInt(frame, stackPointer);
final int size = rawPeekU8(bytecode, offset);
if (index < 0 || index >= size) {
// If unsigned index is larger or equal to the table size use the
// default (last) index.
index = size - 1;
}
final int counterOffset = offset + 1;

if (CompilerDirectives.inInterpreter()) {
if (index < 0 || index >= size) {
// If unsigned index is larger or equal to the table size use the
// default (last) index.
index = size - 1;
}

final int indexOffset = offset + 3 + index * 6;
updateBranchTableProfile(bytecode, offset + 1, indexOffset + 4);
updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4);
final int offsetDelta = rawPeekI32(bytecode, indexOffset);
offset = indexOffset + offsetDelta;
break;
Expand All @@ -453,29 +455,30 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance,
// time constants, since the loop is unrolled.
for (int i = 0; i < size; i++) {
final int indexOffset = offset + 3 + i * 6;
if (profileBranchTable(bytecode, offset + 1, indexOffset + 4, i == index)) {
if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index || i == size - 1)) {
final int offsetDelta = rawPeekI32(bytecode, indexOffset);
offset = indexOffset + offsetDelta;
continue loop;
}
}
throw CompilerDirectives.shouldNotReachHere("br_table");
}
enterErrorBranch();
throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, this, "Should not reach here");
}
case Bytecode.BR_TABLE_I32: {
stackPointer--;
int index = popInt(frame, stackPointer);
final int size = rawPeekI32(bytecode, offset);
if (index < 0 || index >= size) {
// If unsigned index is larger or equal to the table size use the
// default (last) index.
index = size - 1;
}
final int counterOffset = offset + 4;

if (CompilerDirectives.inInterpreter()) {
if (index < 0 || index >= size) {
// If unsigned index is larger or equal to the table size use the
// default (last) index.
index = size - 1;
}

final int indexOffset = offset + 6 + index * 6;
updateBranchTableProfile(bytecode, offset + 4, indexOffset + 4);
updateBranchTableProfile(bytecode, counterOffset, indexOffset + 4);
final int offsetDelta = rawPeekI32(bytecode, indexOffset);
offset = indexOffset + offsetDelta;
break;
Expand All @@ -485,15 +488,14 @@ public Object executeBodyFromOffset(WasmContext context, WasmInstance instance,
// time constants, since the loop is unrolled.
for (int i = 0; i < size; i++) {
final int indexOffset = offset + 6 + i * 6;
if (profileBranchTable(bytecode, offset + 1, indexOffset + 4, i == index)) {
if (profileBranchTable(bytecode, counterOffset, indexOffset + 4, i == index || i == size - 1)) {
final int offsetDelta = rawPeekI32(bytecode, indexOffset);
offset = indexOffset + offsetDelta;
continue loop;
}
}
throw CompilerDirectives.shouldNotReachHere("br_table");
}
enterErrorBranch();
throw WasmException.create(Failure.UNSPECIFIED_INTERNAL, this, "Should not reach here");
}
case Bytecode.CALL_U8:
case Bytecode.CALL_I32: {
Expand Down Expand Up @@ -4679,11 +4681,26 @@ private static boolean profileCondition(byte[] data, final int profileOffset, bo
}

/**
 * Records one interpreter-side hit of a {@code br_table} branch in the profile data that is
 * embedded in the bytecode array.
 *
 * <p>
 * Two unsigned 16-bit counters are read and conditionally incremented: the total hit counter of
 * the table at {@code counterOffset} and the hit counter of the taken branch at
 * {@code profileOffset}. The branch counter is written at most once per call and is never raised
 * to {@code MAX_TABLE_PROFILE_VALUE}.
 *
 * @param data the bytecode array holding the profile counters
 * @param counterOffset offset of the table's total hit counter in {@code data}
 * @param profileOffset offset of the taken branch's hit counter in {@code data}
 */
private static void updateBranchTableProfile(byte[] data, final int counterOffset, final int profileOffset) {
    // Profiles are only updated while interpreting; this method must never be compiled.
    assert CompilerDirectives.inInterpreter();
    CompilerAsserts.neverPartOfCompilation();
    int counter = rawPeekU16(data, counterOffset);
    int profile = rawPeekU16(data, profileOffset);
    /*
     * Even if the total hit counter has already reached the limit, we need to increment the
     * branch profile counter from 0 to 1 iff it's still 0 to mark the branch as having been
     * taken at least once, to prevent recurrent deoptimizations due to profileBranchTable
     * assuming that a value of 0 means the branch has never been reached.
     *
     * Similarly, we need to make sure we never increase any branch counter to the max value,
     * otherwise we can get into a situation where both the branch and the total counter values
     * are at the max value that we cannot recover from since we never decrease counter values;
     * profileBranchTable would then deoptimize every time that branch is not taken.
     */
    assert profile != MAX_TABLE_PROFILE_VALUE;
    if (counter < MAX_TABLE_PROFILE_VALUE) {
        // Only the total counter is bumped here. The branch counter is handled exclusively by
        // the guarded write below, so it can never be saturated to MAX_TABLE_PROFILE_VALUE.
        BinaryStreamParser.writeU16(data, counterOffset, counter + 1);
    }
    if ((counter < MAX_TABLE_PROFILE_VALUE || profile == 0) && (profile < MAX_TABLE_PROFILE_VALUE - 1)) {
        BinaryStreamParser.writeU16(data, profileOffset, profile + 1);
    }
}

Expand Down

0 comments on commit 35c885b

Please sign in to comment.