Skip to content

Commit

Permalink
Add support for tailstrict (#257)
Browse files Browse the repository at this point in the history
With this PR, I'm implementing tailstrict in sjsonnet.
I'm following the guidance from
google/jsonnet#343 (comment):

Quote:
1. If you call a function with tailstrict annotation on the apply AST,
e.g. foo(42) tailstrict then the evaluation of the function body happens
in "tail strict mode". The annotated AST need not actually be a tail
call. Also, the arguments of this function are forced.
2. When another function call (can be a completely different function to
the one that was originally called) is made when we're evaluating in
"tail strict mode" and it is a tail call, then the current stack frame
is re-used for the next call, rather than being pushed on top of.
3. Note that in order to preserve the tail strict mode into the new
function, the new call AST has to be tailstrict as well.

Resolves #189.
  • Loading branch information
stephenamar-db authored Jan 18, 2025
1 parent b025dbd commit f41fd4b
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 51 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ lazy val main = (project in file("sjsonnet"))
.settings(
Compile / scalacOptions ++= Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**"),
Test / fork := true,
Test / javaOptions += "-Xss100m",
Test / baseDirectory := (ThisBuild / baseDirectory).value,
libraryDependencies ++= Seq(
"com.lihaoyi" %% "fastparse" % "2.3.3",
Expand Down
4 changes: 2 additions & 2 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ object sjsonnet extends Module {
ivy"org.yaml:snakeyaml::1.33",
ivy"com.google.re2j:re2j:1.7",
)
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**")
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**")

object test extends ScalaTests with CrossTests {
def forkOptions = Seq("-Xss100m")
def forkArgs = Seq("-Xss100m")
def sources = T.sources(
this.millSourcePath / "src",
this.millSourcePath / "src-jvm",
Expand Down
85 changes: 66 additions & 19 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Expr.{Error => _, _}
import sjsonnet.Expr.Member.Visibility
import ujson.Value

import scala.annotation.tailrec
import scala.collection.mutable

/**
Expand All @@ -27,8 +28,9 @@ class Evaluator(resolver: CachedResolver,

def materialize(v: Val): Value = Materializer.apply(v)
val cachedImports = collection.mutable.HashMap.empty[Path, Val]
var tailstrict: Boolean = false

def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
e match {
case e: ValidId => visitValidId(e)
case e: BinaryOp => visitBinaryOp(e)
Expand Down Expand Up @@ -184,40 +186,84 @@ class Evaluator(resolver: CachedResolver,

private def visitApply(e: Apply)(implicit scope: ValScope) = {
val lhs = visitExpr(e.value)
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1

if (tailstrict) {
lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
tailstrict = false
res
} else {
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}

private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
lhs.cast[Val.Func].apply0(e.pos)
if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply0(e.pos)
tailstrict = false
res
} else {
lhs.cast[Val.Func].apply0(e.pos)
}
}

private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
}
}

private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
}
}

private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
}
}

private def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope) =
Expand Down Expand Up @@ -653,7 +699,8 @@ class Evaluator(resolver: CachedResolver,
newSelf
}

def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
@tailrec
private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
case (spec @ ForSpec(_, name, expr)) :: rest =>
visitComp(
rest,
Expand Down
10 changes: 5 additions & 5 deletions sjsonnet/src/sjsonnet/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ object Expr{
case class ImportStr(pos: Position, value: String) extends Expr
case class ImportBin(pos: Position, value: String) extends Expr
case class Error(pos: Position, value: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String]) extends Expr
case class Apply0(pos: Position, value: Expr) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) extends Expr
case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr, tailstrict: Boolean) extends Expr
case class ApplyBuiltin(pos: Position, func: Val.Builtin, argExprs: Array[Expr]) extends Expr {
override def exprErrorString: String = s"std.${func.functionName}"
}
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/ExprTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,37 @@ abstract class ExprTransform {
if(x2 eq x) expr
else Select(pos, x2, name)

case Apply(pos, x, y, namedNames) =>
case Apply(pos, x, y, namedNames, tailstrict) =>
val x2 = transform(x)
val y2 = transformArr(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply(pos, x2, y2, namedNames)
else Apply(pos, x2, y2, namedNames, tailstrict)

case Apply0(pos, x) =>
case Apply0(pos, x, tailstrict) =>
val x2 = transform(x)
if((x2 eq x)) expr
else Apply0(pos, x2)
else Apply0(pos, x2, tailstrict)

case Apply1(pos, x, y) =>
case Apply1(pos, x, y, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply1(pos, x2, y2)
else Apply1(pos, x2, y2, tailstrict)

case Apply2(pos, x, y, z) =>
case Apply2(pos, x, y, z, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
if((x2 eq x) && (y2 eq y) && (z2 eq z)) expr
else Apply2(pos, x2, y2, z2)
else Apply2(pos, x2, y2, z2, tailstrict)

case Apply3(pos, x, y, z, a) =>
case Apply3(pos, x, y, z, a, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
val a2 = transform(a)
if((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr
else Apply3(pos, x2, y2, z2, a2)
else Apply3(pos, x2, y2, z2, a2, tailstrict)

case ApplyBuiltin(pos, func, x) =>
val x2 = transformArr(x)
Expand Down
4 changes: 2 additions & 2 deletions sjsonnet/src/sjsonnet/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ class Parser(val currentFile: Path,
case (Some(tree), Seq()) => Expr.Lookup(i, _: Expr, tree)
case (start, ins) => Expr.Slice(i, _: Expr, start, ins.lift(0).flatten, ins.lift(1).flatten)
}
case '(' => Pass ~ (args ~ ")").map { case (args, namedNames) =>
Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames)
case '(' => Pass ~ (args ~ ")" ~ "tailstrict".!.?).map {
case (args, namedNames, tailstrict) => Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames, tailstrict.nonEmpty)
}
case '{' => Pass ~ (objinside ~ "}").map(x => Expr.ObjExtend(i, _: Expr, x))
case _ => Fail
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class StaticOptimizer(
}

private def transformApply(a: Apply): Expr = {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames) match {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match {
case null => a
case a => a
}
Expand All @@ -121,7 +121,7 @@ class StaticOptimizer(
}
}

private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = {
private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr], tailstrict: Boolean): Expr = {
if(f.staticSafe && args.forall(_.isInstanceOf[Val])) {
val vargs = args.map(_.asInstanceOf[Val])
try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null }
Expand All @@ -131,20 +131,20 @@ class StaticOptimizer(
private def specializeApplyArity(a: Apply): Expr = {
if(a.namedNames != null) a
else a.args.length match {
case 0 => Apply0(a.pos, a.value)
case 1 => Apply1(a.pos, a.value, a.args(0))
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1))
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2))
case 0 => Apply0(a.pos, a.value, a.tailstrict)
case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict)
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict)
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict)
case _ => a
}
}

private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String]): Expr = lhs match {
private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String], tailstrict: Boolean): Expr = lhs match {
case f: Val.Builtin =>
rebind(args, names, f.params) match {
case null => null
case newArgs =>
tryStaticApply(pos, f, newArgs) match {
tryStaticApply(pos, f, newArgs, tailstrict) match {
case null =>
val (f2, rargs) = f.specialize(newArgs) match {
case null => (f, newArgs)
Expand All @@ -166,12 +166,12 @@ class StaticOptimizer(
case ScopedVal(Function(_, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case ScopedVal(Bind(_, _, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case _ => null
}
Expand Down
1 change: 1 addition & 0 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
q.removeFirst().force match {
case v: Val.Arr => v.asLazyArray.reverseIterator.foreach(q.push)
case s: Val.Str => out.write(s.value)
case _ => Error.fail("Cannot call deepJoin on " + value.prettyName)
}
}
Val.Str(pos, out.toString)
Expand Down
23 changes: 22 additions & 1 deletion sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ object Val{
val simple = namedNames == null && params.names.length == argsL.length
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
//println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}")
if(simple) {
if (simple || ev.tailstrict) {
if (ev.tailstrict) {
argsL.foreach(_.force)
}
val newScope = defSiteValScope.extendSimple(argsL)
evalRhs(newScope, ev, funDefFileScope, outerPos)
} else {
Expand Down Expand Up @@ -530,6 +533,9 @@ object Val{
if(params.names.length != 1) apply(Array(argVal), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand All @@ -539,6 +545,10 @@ object Val{
if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal1.force
argVal2.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand All @@ -548,6 +558,11 @@ object Val{
if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal1.force
argVal2.force
argVal3.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand Down Expand Up @@ -575,6 +590,10 @@ object Val{
if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos)
else evalRhs(Array(argVal1, argVal2), ev, outerPos)

override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos)
else evalRhs(Array(argVal1, argVal2, argVal3), ev, outerPos)

/** Specialize a call to this function in the optimizer. Must return either `null` to leave the
* call-site as it is or a pair of a (possibly different) `Builtin` and the arguments to pass
* to it (usually a subset of the supplied arguments).
Expand Down Expand Up @@ -647,6 +666,8 @@ object Val{
* throughout the Jsonnet evaluation.
*/
abstract class EvalScope extends EvalErrorScope with Ordering[Val] {
def tailstrict: Boolean

def visitExpr(expr: Expr)
(implicit scope: ValScope): Val

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

local fibonacci(n) =
if n <= 1 then
1
else
fibonacci(n - 1) + fibonacci(n - 2);

std.assertEqual(fibonacci(15), 987) &&

// Tail recursive call
local sum(x, v) =
if x <= 0 then
v
else
sum(x - 1, x + v) tailstrict;

// Scala Native is struggling with large stacks.
local sz = 1000;
std.assertEqual(sum(sz, 0), sz * (sz + 1) / 2) &&

std.assertEqual(local x() = 3; x() tailstrict, 3) &&

true
Loading

0 comments on commit f41fd4b

Please sign in to comment.