From 1e345b2397d4aaefc02ee2064b7369235046b024 Mon Sep 17 00:00:00 2001 From: Christian Humer Date: Mon, 13 Jan 2025 19:18:32 +0100 Subject: [PATCH] Fix state optimization with fallback implicit casts. --- .../truffle/api/dsl/test/GR61265Test.java | 92 +++++++++++++++++++ .../generator/FlatNodeGenFactory.java | 53 ++++++++--- .../dsl/processor/generator/MultiBitSet.java | 15 +++ .../processor/model/SpecializationData.java | 17 ++++ 4 files changed, 162 insertions(+), 15 deletions(-) create mode 100644 truffle/src/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/GR61265Test.java diff --git a/truffle/src/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/GR61265Test.java b/truffle/src/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/GR61265Test.java new file mode 100644 index 000000000000..9ff416fe242c --- /dev/null +++ b/truffle/src/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/GR61265Test.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2025, 2025, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * The Universal Permissive License (UPL), Version 1.0 + * + * Subject to the condition set forth below, permission is hereby granted to any + * person obtaining a copy of this software, associated documentation and/or + * data (collectively the "Software"), free of charge and under any and all + * copyright rights in the Software, and any and all patent rights owned or + * freely licensable by each licensor hereunder covering either (i) the + * unmodified Software as contributed to or provided by such licensor, or (ii) + * the Larger Works (as defined below), to deal in both + * + * (a) the Software, and + * + * (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if + * one is included with the Software each a "Larger Work" to which the Software + * is contributed by such licensors), + * + * without restriction, including without limitation the rights to copy, create + * derivative works of, display, perform, and distribute the Software and make, + * use, sell, offer for sale, import, export, have made, and have sold the + * Software and the Larger Work(s), and to sublicense the foregoing rights on + * either these or other terms. + * + * This license is subject to the following condition: + * + * The above copyright notice and either this complete permission notice or at a + * minimum a reference to the UPL must be included in all copies or substantial + * portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +package com.oracle.truffle.api.dsl.test; + +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.ImplicitCast; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.dsl.TypeSystem; +import com.oracle.truffle.api.dsl.TypeSystemReference; +import com.oracle.truffle.api.nodes.Node; + +public class GR61265Test { + + @TypeSystem + abstract static class TypeSystemForFallback { + @ImplicitCast + public static long intToLong(int value) { + return value; + } + } + + @TypeSystemReference(TypeSystemForFallback.class) + @SuppressWarnings({"truffle-inlining", "unused"}) + abstract static class NodeWithFallBack extends Node { + public abstract int execute(Object o1, Object o2); + + @Specialization + int s0(long o1, long o2) { + return 1; + } + + @Specialization + int s1(long o1, String o2) { + return 2; + } + + @Fallback + int fallback(Object o1, Object o2) { + return -1; + } + } + + @Test + public void testFallbackX() { + NodeWithFallBack node = GR61265TestFactory.NodeWithFallBackNodeGen.create(); + assertEquals(2, node.execute(1, "a")); + assertEquals(-1, node.execute("a", "a")); + assertEquals(2, node.execute(1L, "a")); + } +} diff --git a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/FlatNodeGenFactory.java b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/FlatNodeGenFactory.java index fa2cf2139604..8887c02584ec 100644 --- a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/FlatNodeGenFactory.java +++ b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/FlatNodeGenFactory.java @@ -2479,6 +2479,15 @@ private List getFallbackGuards() { return fallbackState; } + private List getFallbackImplicitCastGuards() { + List fallbackState = new ArrayList<>(); + List specializations = getFallbackSpecializations(); + for (SpecializationData specialization : specializations) { + fallbackState.addAll(specialization.getImplicitTypeGuards()); + } + return fallbackState; + } + private Element createFallbackGuard(boolean inlined) { boolean frameUsed = false; @@ -2501,10 +2510,8 @@ private Element createFallbackGuard(boolean inlined) { fallbackNeedsState = false; fallbackNeedsFrame = frameUsed; - multiState.createLoad(frameState, - StateQuery.create(SpecializationActive.class, getFallbackSpecializations()), - StateQuery.create(GuardActive.class, getFallbackGuards())); // already - // loaded + multiState.createLoad(frameState, collectFallbackState()); + multiState.addParametersTo(frameState, method); frameState.addParametersTo(method, Integer.MAX_VALUE, FRAME_VALUE); @@ -4221,19 +4228,24 @@ private boolean needsUnexpectedResultException(ExecutableTypeData executedType) } } + private StateQuery[] collectFallbackState() { + StateQuery fallbackActive = StateQuery.create(SpecializationActive.class, getFallbackSpecializations()); + StateQuery fallbackGuardsActive = StateQuery.create(GuardActive.class, getFallbackGuards()); + StateQuery fallbackImplicitCasts = StateQuery.create(ImplicitCastState.class, getFallbackImplicitCastGuards()); + return new StateQuery[]{fallbackActive, fallbackGuardsActive, fallbackImplicitCasts}; + } + private CodeTree createFastPathExecute(CodeTreeBuilder parent, final ExecutableTypeData forType, SpecializationData specialization, FrameState frameState) { CodeTreeBuilder builder = parent.create(); int ifCount = 0; if (specialization.isFallback()) { - StateQuery fallbackActive = StateQuery.create(SpecializationActive.class, getFallbackSpecializations()); - StateQuery fallbackGuardsActive = StateQuery.create(GuardActive.class, getFallbackGuards()); if (fallbackNeedsState) { - builder.tree(multiState.createLoad(frameState, fallbackActive, fallbackGuardsActive)); + builder.tree(multiState.createLoad(frameState, collectFallbackState())); } builder.startIf().startCall(createFallbackName()); if (fallbackNeedsState) { - multiState.addReferencesTo(frameState, builder, fallbackActive, fallbackGuardsActive); + multiState.addReferencesTo(frameState, builder, collectFallbackState()); } if (fallbackNeedsFrame) { if (frameState.get(FRAME_VALUE) != null) { @@ -4582,14 +4594,12 @@ private CodeTree visitSpecializationGroup(CodeTreeBuilder parent, Specialization NodeExecutionMode mode = frameState.getMode(); boolean hasFallthrough = false; - boolean hasImplicitCast = false; List cachedTriples = new ArrayList<>(); for (TypeGuard guard : group.getTypeGuards()) { IfTriple triple = createTypeCheckOrCast(frameState, group, guard, mode, false, true); if (triple != null) { cachedTriples.add(triple); } - hasImplicitCast = hasImplicitCast || node.getTypeSystem().hasImplicitSourceTypes(guard.getType()); if (!mode.isGuardFallback()) { triple = createTypeCheckOrCast(frameState, group, guard, mode, true, true); if (triple != null) { @@ -4668,7 +4678,7 @@ private CodeTree visitSpecializationGroup(CodeTreeBuilder parent, Specialization cachedTriples = IfTriple.optimize(cachedTriples); - if (specialization != null && !hasImplicitCast) { + if (specialization != null) { IfTriple singleCondition = null; if (cachedTriples.size() == 1) { singleCondition = cachedTriples.get(0); @@ -4676,9 +4686,26 @@ private CodeTree visitSpecializationGroup(CodeTreeBuilder parent, Specialization if (singleCondition != null) { int index = cachedTriples.indexOf(singleCondition); CodeTreeBuilder b = new CodeTreeBuilder(parent); + b.string("!("); b.tree(createSpecializationActiveCheck(frameState, Arrays.asList(specialization))); + + /* + * We can only optimize the fallback type check away if all implicit cast bits + * were exercised. Otherwise we need to keep in all implicit cast checks. + */ + List guards = specialization.getImplicitTypeGuards(); + if (!guards.isEmpty()) { + StateQuery query = StateQuery.create(ImplicitCastState.class, guards); + CodeTree stateCheck = multiState.createContainsAll(frameState, query); + if (!stateCheck.isEmpty()) { + b.newLine(); + b.string(" && ").tree(stateCheck); + } + } + b.string(")"); + cachedTriples.set(index, new IfTriple(singleCondition.prepare, combineTrees(" && ", b.build(), singleCondition.condition), singleCondition.statements)); fallbackNeedsState = true; } @@ -7704,10 +7731,6 @@ int getAllCapacity() { return length; } - CodeTree createContainsAll(FrameState frameState, StateQuery elements) { - return createContainsImpl(all, frameState, elements); - } - List createCachedFields() { List variables = new ArrayList<>(); for (BitSet bitSet : all) { diff --git a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/MultiBitSet.java b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/MultiBitSet.java index 5bb608a14bf7..8d70cbde92a0 100644 --- a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/MultiBitSet.java +++ b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/generator/MultiBitSet.java @@ -68,6 +68,21 @@ public int getCapacity() { return length; } + public CodeTree createContainsAll(FrameState frameState, StateQuery elements) { + CodeTreeBuilder builder = CodeTreeBuilder.createBuilder(); + String sep = ""; + for (BitSet set : sets) { + StateQuery selected = set.filter(elements); + if (!selected.isEmpty()) { + CodeTree containsAll = set.createIs(frameState, selected, selected); + builder.string(sep); + builder.tree(containsAll); + sep = " && "; + } + } + return builder.build(); + } + public CodeTree createContains(FrameState frameState, StateQuery elements) { return createContainsImpl(sets, frameState, elements); } diff --git a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/model/SpecializationData.java b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/model/SpecializationData.java index e16879a8d828..cd6c82441c06 100644 --- a/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/model/SpecializationData.java +++ b/truffle/src/com.oracle.truffle.dsl.processor/src/com/oracle/truffle/dsl/processor/model/SpecializationData.java @@ -64,6 +64,7 @@ import com.oracle.truffle.dsl.processor.generator.FlatNodeGenFactory; import com.oracle.truffle.dsl.processor.java.ElementUtils; import com.oracle.truffle.dsl.processor.parser.NodeParser; +import com.oracle.truffle.dsl.processor.parser.SpecializationGroup.TypeGuard; public final class SpecializationData extends TemplateMethod { @@ -157,6 +158,22 @@ public SpecializationData copy() { return copy; } + public List getImplicitTypeGuards() { + TypeSystemData typeSystem = getNode().getTypeSystem(); + if (typeSystem.getImplicitCasts().isEmpty()) { + return List.of(); + } + int signatureIndex = 0; + List implicitTypeChecks = new ArrayList<>(); + for (Parameter p : getDynamicParameters()) { + if (typeSystem.hasImplicitSourceTypes(p.getType())) { + implicitTypeChecks.add(new TypeGuard(typeSystem, p.getType(), signatureIndex)); + } + signatureIndex++; + } + return implicitTypeChecks; + } + public boolean isNodeReceiverVariable(VariableElement var) { if (getNode().isGenerateInline()) { Parameter p = findByVariable(var);