Skip to content

Commit

Permalink
Fix state optimization with fallback implicit casts.
Browse files Browse the repository at this point in the history
  • Loading branch information
chumer committed Jan 14, 2025
1 parent 35c885b commit 1e345b2
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2025, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* The Universal Permissive License (UPL), Version 1.0
*
* Subject to the condition set forth below, permission is hereby granted to any
* person obtaining a copy of this software, associated documentation and/or
* data (collectively the "Software"), free of charge and under any and all
* copyright rights in the Software, and any and all patent rights owned or
* freely licensable by each licensor hereunder covering either (i) the
* unmodified Software as contributed to or provided by such licensor, or (ii)
* the Larger Works (as defined below), to deal in both
*
* (a) the Software, and
*
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
* one is included with the Software each a "Larger Work" to which the Software
* is contributed by such licensors),
*
* without restriction, including without limitation the rights to copy, create
* derivative works of, display, perform, and distribute the Software and make,
* use, sell, offer for sale, import, export, have made, and have sold the
* Software and the Larger Work(s), and to sublicense the foregoing rights on
* either these or other terms.
*
* This license is subject to the following condition:
*
* The above copyright notice and either this complete permission notice or at a
* minimum a reference to the UPL must be included in all copies or substantial
* portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package com.oracle.truffle.api.dsl.test;

import static org.junit.Assert.assertEquals;

import org.junit.Test;

import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.ImplicitCast;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.dsl.TypeSystem;
import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.nodes.Node;

public class GR61265Test {

@TypeSystem
abstract static class TypeSystemForFallback {
@ImplicitCast
public static long intToLong(int value) {
return value;
}
}

@TypeSystemReference(TypeSystemForFallback.class)
@SuppressWarnings({"truffle-inlining", "unused"})
abstract static class NodeWithFallBack extends Node {
public abstract int execute(Object o1, Object o2);

@Specialization
int s0(long o1, long o2) {
return 1;
}

@Specialization
int s1(long o1, String o2) {
return 2;
}

@Fallback
int fallback(Object o1, Object o2) {
return -1;
}
}

@Test
public void testFallbackX() {
NodeWithFallBack node = GR61265TestFactory.NodeWithFallBackNodeGen.create();
assertEquals(2, node.execute(1, "a"));
assertEquals(-1, node.execute("a", "a"));
assertEquals(2, node.execute(1L, "a"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,15 @@ private List<GuardExpression> getFallbackGuards() {
return fallbackState;
}

private List<TypeGuard> getFallbackImplicitCastGuards() {
List<TypeGuard> fallbackState = new ArrayList<>();
List<SpecializationData> specializations = getFallbackSpecializations();
for (SpecializationData specialization : specializations) {
fallbackState.addAll(specialization.getImplicitTypeGuards());
}
return fallbackState;
}

private Element createFallbackGuard(boolean inlined) {
boolean frameUsed = false;

Expand All @@ -2501,10 +2510,8 @@ private Element createFallbackGuard(boolean inlined) {
fallbackNeedsState = false;
fallbackNeedsFrame = frameUsed;

multiState.createLoad(frameState,
StateQuery.create(SpecializationActive.class, getFallbackSpecializations()),
StateQuery.create(GuardActive.class, getFallbackGuards())); // already
// loaded
multiState.createLoad(frameState, collectFallbackState());

multiState.addParametersTo(frameState, method);
frameState.addParametersTo(method, Integer.MAX_VALUE, FRAME_VALUE);

Expand Down Expand Up @@ -4221,19 +4228,24 @@ private boolean needsUnexpectedResultException(ExecutableTypeData executedType)
}
}

private StateQuery[] collectFallbackState() {
StateQuery fallbackActive = StateQuery.create(SpecializationActive.class, getFallbackSpecializations());
StateQuery fallbackGuardsActive = StateQuery.create(GuardActive.class, getFallbackGuards());
StateQuery fallbackImplicitCasts = StateQuery.create(ImplicitCastState.class, getFallbackImplicitCastGuards());
return new StateQuery[]{fallbackActive, fallbackGuardsActive, fallbackImplicitCasts};
}

private CodeTree createFastPathExecute(CodeTreeBuilder parent, final ExecutableTypeData forType, SpecializationData specialization, FrameState frameState) {
CodeTreeBuilder builder = parent.create();
int ifCount = 0;
if (specialization.isFallback()) {
StateQuery fallbackActive = StateQuery.create(SpecializationActive.class, getFallbackSpecializations());
StateQuery fallbackGuardsActive = StateQuery.create(GuardActive.class, getFallbackGuards());

if (fallbackNeedsState) {
builder.tree(multiState.createLoad(frameState, fallbackActive, fallbackGuardsActive));
builder.tree(multiState.createLoad(frameState, collectFallbackState()));
}
builder.startIf().startCall(createFallbackName());
if (fallbackNeedsState) {
multiState.addReferencesTo(frameState, builder, fallbackActive, fallbackGuardsActive);
multiState.addReferencesTo(frameState, builder, collectFallbackState());
}
if (fallbackNeedsFrame) {
if (frameState.get(FRAME_VALUE) != null) {
Expand Down Expand Up @@ -4582,14 +4594,12 @@ private CodeTree visitSpecializationGroup(CodeTreeBuilder parent, Specialization

NodeExecutionMode mode = frameState.getMode();
boolean hasFallthrough = false;
boolean hasImplicitCast = false;
List<IfTriple> cachedTriples = new ArrayList<>();
for (TypeGuard guard : group.getTypeGuards()) {
IfTriple triple = createTypeCheckOrCast(frameState, group, guard, mode, false, true);
if (triple != null) {
cachedTriples.add(triple);
}
hasImplicitCast = hasImplicitCast || node.getTypeSystem().hasImplicitSourceTypes(guard.getType());
if (!mode.isGuardFallback()) {
triple = createTypeCheckOrCast(frameState, group, guard, mode, true, true);
if (triple != null) {
Expand Down Expand Up @@ -4668,17 +4678,34 @@ private CodeTree visitSpecializationGroup(CodeTreeBuilder parent, Specialization

cachedTriples = IfTriple.optimize(cachedTriples);

if (specialization != null && !hasImplicitCast) {
if (specialization != null) {
IfTriple singleCondition = null;
if (cachedTriples.size() == 1) {
singleCondition = cachedTriples.get(0);
}
if (singleCondition != null) {
int index = cachedTriples.indexOf(singleCondition);
CodeTreeBuilder b = new CodeTreeBuilder(parent);

b.string("!(");
b.tree(createSpecializationActiveCheck(frameState, Arrays.asList(specialization)));

/*
* We can only optimize the fallback type check away if all implicit cast bits
* were exercised. Otherwise we need to keep in all implicit cast checks.
*/
List<TypeGuard> guards = specialization.getImplicitTypeGuards();
if (!guards.isEmpty()) {
StateQuery query = StateQuery.create(ImplicitCastState.class, guards);
CodeTree stateCheck = multiState.createContainsAll(frameState, query);
if (!stateCheck.isEmpty()) {
b.newLine();
b.string(" && ").tree(stateCheck);
}
}

b.string(")");

cachedTriples.set(index, new IfTriple(singleCondition.prepare, combineTrees(" && ", b.build(), singleCondition.condition), singleCondition.statements));
fallbackNeedsState = true;
}
Expand Down Expand Up @@ -7704,10 +7731,6 @@ int getAllCapacity() {
return length;
}

CodeTree createContainsAll(FrameState frameState, StateQuery elements) {
return createContainsImpl(all, frameState, elements);
}

List<CodeVariableElement> createCachedFields() {
List<CodeVariableElement> variables = new ArrayList<>();
for (BitSet bitSet : all) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,21 @@ public int getCapacity() {
return length;
}

public CodeTree createContainsAll(FrameState frameState, StateQuery elements) {
CodeTreeBuilder builder = CodeTreeBuilder.createBuilder();
String sep = "";
for (BitSet set : sets) {
StateQuery selected = set.filter(elements);
if (!selected.isEmpty()) {
CodeTree containsAll = set.createIs(frameState, selected, selected);
builder.string(sep);
builder.tree(containsAll);
sep = " && ";
}
}
return builder.build();
}

public CodeTree createContains(FrameState frameState, StateQuery elements) {
return createContainsImpl(sets, frameState, elements);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import com.oracle.truffle.dsl.processor.generator.FlatNodeGenFactory;
import com.oracle.truffle.dsl.processor.java.ElementUtils;
import com.oracle.truffle.dsl.processor.parser.NodeParser;
import com.oracle.truffle.dsl.processor.parser.SpecializationGroup.TypeGuard;

public final class SpecializationData extends TemplateMethod {

Expand Down Expand Up @@ -157,6 +158,22 @@ public SpecializationData copy() {
return copy;
}

public List<TypeGuard> getImplicitTypeGuards() {
TypeSystemData typeSystem = getNode().getTypeSystem();
if (typeSystem.getImplicitCasts().isEmpty()) {
return List.of();
}
int signatureIndex = 0;
List<TypeGuard> implicitTypeChecks = new ArrayList<>();
for (Parameter p : getDynamicParameters()) {
if (typeSystem.hasImplicitSourceTypes(p.getType())) {
implicitTypeChecks.add(new TypeGuard(typeSystem, p.getType(), signatureIndex));
}
signatureIndex++;
}
return implicitTypeChecks;
}

public boolean isNodeReceiverVariable(VariableElement var) {
if (getNode().isGenerateInline()) {
Parameter p = findByVariable(var);
Expand Down

0 comments on commit 1e345b2

Please sign in to comment.