Skip to content

Commit

Permalink
Uncached threshold: explicitly support Integer.MIN_VALUE as a way to …
Browse files Browse the repository at this point in the history
…force uncached.
  • Loading branch information
DSouzaM committed Sep 6, 2024
1 parent 186271f commit b2da720
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ public class SimpleBytecodeBenchmark extends TruffleBenchmark {
private static final Source SOURCE_MANUAL_NODED_NO_BE = Source.create("bm", NAME_MANUAL_NODED_NO_BE);
private static final Source SOURCE_AST = Source.create("bm", NAME_AST);

// Keep the uncached interpreter around so we can manually reset its invocation threshold.
private static BytecodeBenchmarkRootNode bytecodeUncachedRootNode;

/**
* The benchmark programs implement:
*
Expand Down Expand Up @@ -317,7 +314,7 @@ private static BytecodeParser<BytecodeBenchmarkRootNodeBuilder> createBytecodeDS

BytecodeBenchmarkRootNode root = b.endRoot();
if (forceUncached) {
root.getBytecodeNode().setUncachedThreshold(Integer.MAX_VALUE);
root.getBytecodeNode().setUncachedThreshold(Integer.MIN_VALUE);
}
};
}
Expand Down Expand Up @@ -406,18 +403,6 @@ public void enterContext() {
context.enter();
}

@Setup(Level.Invocation)
public void resetThreshold() {
/**
* Ensure the invocation threshold does not get hit. The number of loop back-edges is
* several orders of magnitude less than this threshold, so it should never transition to
* the cached interpreter.
*/
if (bytecodeUncachedRootNode != null) {
bytecodeUncachedRootNode.getBytecodeNode().setUncachedThreshold(Integer.MAX_VALUE);
}
}

@TearDown(Level.Iteration)
public void leaveContext() {
context.leave();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import com.oracle.truffle.api.bytecode.BytecodeNode;
import com.oracle.truffle.api.bytecode.BytecodeRootNode;
import com.oracle.truffle.api.bytecode.BytecodeRootNodes;
import com.oracle.truffle.api.bytecode.BytecodeTier;
import com.oracle.truffle.api.bytecode.ContinuationResult;
import com.oracle.truffle.api.bytecode.ExceptionHandler;
import com.oracle.truffle.api.bytecode.Instruction;
Expand Down Expand Up @@ -2087,6 +2088,151 @@ public void testTooManyStackValues() {

}

@Test
public void testTransitionToCached() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("transitionToCached", b -> {
b.beginRoot();
b.beginReturn();
b.emitLoadConstant(42L);
b.endReturn();
b.endRoot();
});

node.getBytecodeNode().setUncachedThreshold(50);
for (int i = 0; i < 50; i++) {
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(42L, node.getCallTarget().call());
}
assertEquals(BytecodeTier.CACHED, node.getBytecodeNode().getTier());
assertEquals(42L, node.getCallTarget().call());
}

@Test
public void testTransitionToCachedImmediately() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("transitionToCachedImmediately", b -> {
b.beginRoot();
b.beginReturn();
b.emitLoadConstant(42L);
b.endReturn();
b.endRoot();
});

node.getBytecodeNode().setUncachedThreshold(0);
// The bytecode node will transition to cached on the first call.
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(42L, node.getCallTarget().call());
assertEquals(BytecodeTier.CACHED, node.getBytecodeNode().getTier());
}

@Test
public void testTransitionToCachedBadThreshold() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("transitionToCachedBadThreshold", b -> {
b.beginRoot();
b.beginReturn();
b.emitLoadConstant(42L);
b.endReturn();
b.endRoot();
});

assertThrows(IllegalArgumentException.class, () -> node.getBytecodeNode().setUncachedThreshold(-1));
}

@Test
public void testTransitionToCachedLoop() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("transitionToCachedLoop", b -> {
b.beginRoot();
BytecodeLocal i = b.createLocal();
b.beginStoreLocal(i);
b.emitLoadConstant(0L);
b.endStoreLocal();

b.beginWhile();
b.beginLess();
b.emitLoadLocal(i);
b.emitLoadArgument(0);
b.endLess();

b.beginStoreLocal(i);
b.beginAddConstantOperation(1L);
b.emitLoadLocal(i);
b.endAddConstantOperation();
b.endStoreLocal();
b.endWhile();

b.beginReturn();
b.emitLoadLocal(i);
b.endReturn();

b.endRoot();
});

node.getBytecodeNode().setUncachedThreshold(50);
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(24L, node.getCallTarget().call(24L)); // 24 back edges + 1 return
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(24L, node.getCallTarget().call(24L)); // 24 back edges + 1 return
assertEquals(BytecodeTier.CACHED, node.getBytecodeNode().getTier());
assertEquals(24L, node.getCallTarget().call(24L));
}

@Test
public void testDisableTransitionToCached() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("disableTransitionToCached", b -> {
b.beginRoot();
b.beginReturn();
b.emitLoadConstant(42L);
b.endReturn();
b.endRoot();
});

node.getBytecodeNode().setUncachedThreshold(Integer.MIN_VALUE);
for (int i = 0; i < 50; i++) {
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(42L, node.getCallTarget().call());
}
}

@Test
public void testDisableTransitionToCachedLoop() {
assumeTrue(run.hasUncachedInterpreter());
BasicInterpreter node = parseNode("disableTransitionToCachedLoop", b -> {
b.beginRoot();
BytecodeLocal i = b.createLocal();
b.beginStoreLocal(i);
b.emitLoadConstant(0L);
b.endStoreLocal();

b.beginWhile();
b.beginLess();
b.emitLoadLocal(i);
b.emitLoadArgument(0);
b.endLess();

b.beginStoreLocal(i);
b.beginAddConstantOperation(1L);
b.emitLoadLocal(i);
b.endAddConstantOperation();
b.endStoreLocal();
b.endWhile();

b.beginReturn();
b.emitLoadLocal(i);
b.endReturn();

b.endRoot();
});

node.getBytecodeNode().setUncachedThreshold(Integer.MIN_VALUE);
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
assertEquals(50L, node.getCallTarget().call(50L));
assertEquals(BytecodeTier.UNCACHED, node.getBytecodeNode().getTier());
}

@Test
public void testIntrospectionDataInstructions() {
BasicInterpreter node = parseNode("introspectionDataInstructions", b -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@
* {@link BytecodeNode} is created and automatically replaced.
* <p>
* The {@link #getTier() tier} of a bytecode node initially always starts out as
* {@link BytecodeTier#UNCACHED}. This means that no cached nodes were created yet. It takes number
* {@link #setUncachedThreshold(int) uncached threshold} calls and back-edges for the node to
* transition to the cached tier. BY default the uncached threshold is 16 if the
* {@link GenerateBytecode#enableUncachedInterpreter() uncached generation} is enabled, and 0 if
* not. The intention of the uncached bytecode tier is to reduce footprint overhead for root nodes
* that are only executed infrequently.
* {@link BytecodeTier#UNCACHED}. This means that no cached nodes were created yet. The
* {@link #setUncachedThreshold(int) uncached threshold} determines how many calls, back-edges, and
* yields are necessary for the node to transition to the cached tier. By default the uncached
* threshold is 16 if the {@link GenerateBytecode#enableUncachedInterpreter() uncached interpreter}
* is enabled, and 0 if not (i.e., it will transition to cached on the first execution). The
* intention of the uncached bytecode tier is to reduce the footprint of root nodes that are only
* executed infrequently.
* <p>
* Since the the number of bytecodes may change between bytecode nodes of a root node, a
* bytecodeIndex returned by the the DSL is only valid for a single bytecode node, it is therefore
Expand Down Expand Up @@ -942,13 +943,17 @@ public void setLocalValueBoolean(int bytecodeIndex, Frame frame, int localOffset
public abstract List<LocalVariable> getLocals();

/**
* Sets a threshold that must be reached before the uncached interpreter switches to a cached
* interpreter. The interpreter can switch to cached when the number of times it returns,
* yields, and branches backwards exceeds the threshold.
* Sets the number of times an uncached interpreter must return, branch backwards, or yield
* before transitioning to cached. The default threshold is {@code 16}. The {@code threshold}
* should be a positive value, {@code 0}, or {@code Integer.MIN_VALUE}. A threshold of {@code 0}
* forces the uncached interpreter to transition to cached on the next invocation. A threshold
* of {@code Integer.MIN_VALUE} forces the uncached interpreter to stay uncached (i.e., it will
* not transition to cached).
* <p>
* This method has no effect if an uncached interpreter is not
* {@link GenerateBytecode#enableUncachedInterpreter enabled} or the root node has already
* switched to a specializing interpreter.
* This method should be called before executing the root node. It will not have any effect on
* an uncached interpreter that is currently executing, an interpreter that is already cached,
* or an interpreter that does not {@link GenerateBytecode#enableUncachedInterpreter enable
* uncached}.
*
* @since 24.2
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11520,6 +11520,7 @@ private CodeExecutableElement createGetTagTree() {
final class BytecodeNodeElement extends CodeTypeElement {

private static final String METADATA_FIELD_NAME = "osrMetadata_";
private static final String FORCE_UNCACHED_THRESHOLD = "Integer.MIN_VALUE";
private final InterpreterTier tier;
private final Map<InstructionModel, CodeExecutableElement> doInstructionMethods = new LinkedHashMap<>();
private final CodeTypeElement interpreterStateElement;
Expand Down Expand Up @@ -12294,13 +12295,17 @@ public int compareTo(CachedInitializationKey o) {
}

private CodeExecutableElement createSetUncachedThreshold() {
CodeExecutableElement ex = GeneratorUtils.override(types.BytecodeNode, "setUncachedThreshold", new String[]{"invocationCount"}, new TypeMirror[]{type(int.class)});
CodeExecutableElement ex = GeneratorUtils.override(types.BytecodeNode, "setUncachedThreshold", new String[]{"threshold"}, new TypeMirror[]{type(int.class)});
ElementUtils.setVisibility(ex.getModifiers(), PUBLIC);
ex.getModifiers().remove(ABSTRACT);

CodeTreeBuilder b = ex.createBuilder();
if (tier.isUncached()) {
b.startAssign("uncachedExecuteCount_").string("invocationCount").end();
b.tree(createNeverPartOfCompilation());
b.startIf().string("threshold < 0 && threshold != ", FORCE_UNCACHED_THRESHOLD).end().startBlock();
emitThrow(b, IllegalArgumentException.class, "\"threshold cannot be a negative value other than " + FORCE_UNCACHED_THRESHOLD + "\"");
b.end();
b.startAssign("uncachedExecuteCount_").string("threshold").end();
} else {
// do nothing for cached
}
Expand Down Expand Up @@ -12338,7 +12343,7 @@ private List<CodeExecutableElement> createContinueAt() {
b.startTryBlock();

b.statement("int uncachedExecuteCount = this.uncachedExecuteCount_");
b.startIf().string("uncachedExecuteCount <= 0").end().startBlock();
b.startIf().string("uncachedExecuteCount <= 0 && uncachedExecuteCount != ", FORCE_UNCACHED_THRESHOLD).end().startBlock();
b.statement("$root.transitionToCached(frame, 0)");
b.startReturn().string("startState").end();
b.end();
Expand Down Expand Up @@ -12615,11 +12620,15 @@ private void buildInstructionCaseBlock(CodeTreeBuilder b, InstructionModel instr
case BRANCH_BACKWARD:
if (tier.isUncached()) {
b.statement("bci = " + readImmediate("bc", "bci", instr.getImmediate(ImmediateKind.BYTECODE_INDEX)));
b.startIf().string("--uncachedExecuteCount <= 0").end().startBlock();
b.startIf().string("uncachedExecuteCount <= 1").end().startBlock();
b.startIf().string("uncachedExecuteCount != ", FORCE_UNCACHED_THRESHOLD).end().startBlock();
b.tree(GeneratorUtils.createTransferToInterpreterAndInvalidate());
b.statement("$root.transitionToCached(frame, bci)");
b.statement("return ", encodeState("bci", "sp"));
b.end();
b.end().startElseBlock();
b.statement("uncachedExecuteCount--");
b.end();
} else {
emitReportLoopCount(b, CodeTreeBuilder.createBuilder().string("++loopCounter.value >= ").staticReference(loopCounter.asType(), "REPORT_LOOP_STRIDE").build(), true);

Expand Down Expand Up @@ -15183,11 +15192,17 @@ private static void emitReturnTopOfStack(CodeTreeBuilder b) {

private void emitBeforeReturnProfiling(CodeTreeBuilder b) {
if (tier.isUncached()) {
b.statement("uncachedExecuteCount--");
b.startIf().string("uncachedExecuteCount <= 0").end().startBlock();
b.startIf().string("uncachedExecuteCount <= 1").end().startBlock();
/*
* The force uncached check is put in here so that we don't need to check it in the
* common case (the else branch where we just decrement).
*/
b.startIf().string("uncachedExecuteCount != ", FORCE_UNCACHED_THRESHOLD).end().startBlock();
b.tree(GeneratorUtils.createTransferToInterpreterAndInvalidate());
b.statement("this.getRoot().transitionToCached(frame, bci)");
b.statement("$root.transitionToCached(frame, bci)");
b.end();
b.end().startElseBlock();
b.statement("uncachedExecuteCount--");
b.statement("this.uncachedExecuteCount_ = uncachedExecuteCount");
b.end();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public static void parseSL(SLLanguage language, Source source, Map<TruffleString
node.getBytecodeNode().setUncachedThreshold(0);
break;
case UNCACHED:
node.getBytecodeNode().setUncachedThreshold(Integer.MAX_VALUE);
node.getBytecodeNode().setUncachedThreshold(Integer.MIN_VALUE);
break;
default:
throw CompilerDirectives.shouldNotReachHere("Unexpected tier: " + tier);
Expand Down

0 comments on commit b2da720

Please sign in to comment.