From f41fd4b78c32b4d5a56da234e6115dbb884c5d96 Mon Sep 17 00:00:00 2001 From: Stephen Amar Date: Sat, 18 Jan 2025 15:31:23 -0800 Subject: [PATCH] Add support for tailstrict (#257) With this PR, I'm implementing tailstrict in sjsonnet. I'm following the guidance from https://github.com/google/jsonnet/issues/343#issuecomment-325987638: Quote: 1. If you call a function with tailstrict annotation on the apply AST, e.g. foo(42) tailstrict then the evaluation of the function body happens in "tail strict mode". The annotated AST need not actually be a tail call. Also, the arguments of this function are forced. 2. When another function call (can be a completely different function to the one that was originally called) is made when we're evaluating in "tail strict mode" and it is a tail call, then the current stack frame is re-used for the next call, rather than being pushed on top of. 3. Note that in order to preserve the tail strict mode into the new function, the new call AST has to be tailstrict as well. Resolves #189. --- build.sbt | 1 + build.sc | 4 +- sjsonnet/src/sjsonnet/Evaluator.scala | 85 ++++++++++++++----- sjsonnet/src/sjsonnet/Expr.scala | 10 +-- sjsonnet/src/sjsonnet/ExprTransform.scala | 20 ++--- sjsonnet/src/sjsonnet/Parser.scala | 4 +- sjsonnet/src/sjsonnet/StaticOptimizer.scala | 20 ++--- sjsonnet/src/sjsonnet/Std.scala | 1 + sjsonnet/src/sjsonnet/Val.scala | 23 ++++- .../recursive_function_native.jsonnet | 38 +++++++++ .../test/src-jvm/sjsonnet/FileTests.scala | 2 +- .../test/src-native/sjsonnet/FileTests.scala | 2 +- sjsonnet/test/src/sjsonnet/FormatTests.scala | 1 + 13 files changed, 160 insertions(+), 51 deletions(-) create mode 100644 sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet diff --git a/build.sbt b/build.sbt index 1a6a9913..2d4db46b 100644 --- a/build.sbt +++ b/build.sbt @@ -8,6 +8,7 @@ lazy val main = (project in file("sjsonnet")) .settings( Compile / scalacOptions ++= Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**"), Test / fork := true, + Test / javaOptions += "-Xss100m", Test / baseDirectory := (ThisBuild / baseDirectory).value, libraryDependencies ++= Seq( "com.lihaoyi" %% "fastparse" % "2.3.3", diff --git a/build.sc b/build.sc index 041b1876..88457977 100644 --- a/build.sc +++ b/build.sc @@ -111,10 +111,10 @@ object sjsonnet extends Module { ivy"org.yaml:snakeyaml::1.33", ivy"com.google.re2j:re2j:1.7", ) - def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**") + def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**") object test extends ScalaTests with CrossTests { - def forkOptions = Seq("-Xss100m") + def forkArgs = Seq("-Xss100m") def sources = T.sources( this.millSourcePath / "src", this.millSourcePath / "src-jvm", diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 765705e8..5c3a7f90 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -4,6 +4,7 @@ import Expr.{Error => _, _} import sjsonnet.Expr.Member.Visibility import ujson.Value +import scala.annotation.tailrec import scala.collection.mutable /** @@ -27,8 +28,9 @@ class Evaluator(resolver: CachedResolver, def materialize(v: Val): Value = Materializer.apply(v) val cachedImports = collection.mutable.HashMap.empty[Path, Val] + var tailstrict: Boolean = false - def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { + override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { e match { case e: ValidId => visitValidId(e) case e: BinaryOp => visitBinaryOp(e) @@ -184,40 +186,84 @@ class Evaluator(resolver: CachedResolver, private def visitApply(e: Apply)(implicit scope: ValScope) = { val lhs = visitExpr(e.value) - val args = e.args - val argsL = new Array[Lazy](args.length) - var idx = 0 - while (idx < args.length) { - argsL(idx) = visitAsLazy(args(idx)) - idx += 1 + + if (tailstrict) { + lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos) + tailstrict = false + res + } else { + val args = e.args + val argsL = new Array[Lazy](args.length) + var idx = 0 + while (idx < args.length) { + argsL(idx) = visitAsLazy(args(idx)) + idx += 1 + } + lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos) } - lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos) } private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - lhs.cast[Val.Func].apply0(e.pos) + if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply0(e.pos) + tailstrict = false + res + } else { + lhs.cast[Val.Func].apply0(e.pos) + } } private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - lhs.cast[Val.Func].apply1(l1, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + lhs.cast[Val.Func].apply1(l1, e.pos) + } } private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - val l2 = visitAsLazy(e.a2) - lhs.cast[Val.Func].apply2(l1, l2, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + val l2 = visitAsLazy(e.a2) + lhs.cast[Val.Func].apply2(l1, l2, e.pos) + } } private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - val l2 = visitAsLazy(e.a2) - val l3 = visitAsLazy(e.a3) - lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + val l2 = visitAsLazy(e.a2) + val l3 = visitAsLazy(e.a3) + lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos) + } } private def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope) = @@ -653,7 +699,8 @@ class Evaluator(resolver: CachedResolver, newSelf } - def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{ + @tailrec + private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{ case (spec @ ForSpec(_, name, expr)) :: rest => visitComp( rest, diff --git a/sjsonnet/src/sjsonnet/Expr.scala b/sjsonnet/src/sjsonnet/Expr.scala index 14643887..e014492e 100644 --- a/sjsonnet/src/sjsonnet/Expr.scala +++ b/sjsonnet/src/sjsonnet/Expr.scala @@ -128,11 +128,11 @@ object Expr{ case class ImportStr(pos: Position, value: String) extends Expr case class ImportBin(pos: Position, value: String) extends Expr case class Error(pos: Position, value: Expr) extends Expr - case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String]) extends Expr - case class Apply0(pos: Position, value: Expr) extends Expr - case class Apply1(pos: Position, value: Expr, a1: Expr) extends Expr - case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr) extends Expr - case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr) extends Expr + case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) extends Expr + case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr + case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr + case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) extends Expr + case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr, tailstrict: Boolean) extends Expr case class ApplyBuiltin(pos: Position, func: Val.Builtin, argExprs: Array[Expr]) extends Expr { override def exprErrorString: String = s"std.${func.functionName}" } diff --git a/sjsonnet/src/sjsonnet/ExprTransform.scala b/sjsonnet/src/sjsonnet/ExprTransform.scala index 964833ee..d7d0657d 100644 --- a/sjsonnet/src/sjsonnet/ExprTransform.scala +++ b/sjsonnet/src/sjsonnet/ExprTransform.scala @@ -14,37 +14,37 @@ abstract class ExprTransform { if(x2 eq x) expr else Select(pos, x2, name) - case Apply(pos, x, y, namedNames) => + case Apply(pos, x, y, namedNames, tailstrict) => val x2 = transform(x) val y2 = transformArr(y) if((x2 eq x) && (y2 eq y)) expr - else Apply(pos, x2, y2, namedNames) + else Apply(pos, x2, y2, namedNames, tailstrict) - case Apply0(pos, x) => + case Apply0(pos, x, tailstrict) => val x2 = transform(x) if((x2 eq x)) expr - else Apply0(pos, x2) + else Apply0(pos, x2, tailstrict) - case Apply1(pos, x, y) => + case Apply1(pos, x, y, tailstrict) => val x2 = transform(x) val y2 = transform(y) if((x2 eq x) && (y2 eq y)) expr - else Apply1(pos, x2, y2) + else Apply1(pos, x2, y2, tailstrict) - case Apply2(pos, x, y, z) => + case Apply2(pos, x, y, z, tailstrict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) if((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Apply2(pos, x2, y2, z2) + else Apply2(pos, x2, y2, z2, tailstrict) - case Apply3(pos, x, y, z, a) => + case Apply3(pos, x, y, z, a, tailstrict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) val a2 = transform(a) if((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else Apply3(pos, x2, y2, z2, a2) + else Apply3(pos, x2, y2, z2, a2, tailstrict) case ApplyBuiltin(pos, func, x) => val x2 = transformArr(x) diff --git a/sjsonnet/src/sjsonnet/Parser.scala b/sjsonnet/src/sjsonnet/Parser.scala index 1c06df6a..ffc01dfb 100644 --- a/sjsonnet/src/sjsonnet/Parser.scala +++ b/sjsonnet/src/sjsonnet/Parser.scala @@ -227,8 +227,8 @@ class Parser(val currentFile: Path, case (Some(tree), Seq()) => Expr.Lookup(i, _: Expr, tree) case (start, ins) => Expr.Slice(i, _: Expr, start, ins.lift(0).flatten, ins.lift(1).flatten) } - case '(' => Pass ~ (args ~ ")").map { case (args, namedNames) => - Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames) + case '(' => Pass ~ (args ~ ")" ~ "tailstrict".!.?).map { + case (args, namedNames, tailstrict) => Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames, tailstrict.nonEmpty) } case '{' => Pass ~ (objinside ~ "}").map(x => Expr.ObjExtend(i, _: Expr, x)) case _ => Fail diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 6e1e2854..694ef7fb 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -111,7 +111,7 @@ class StaticOptimizer( } private def transformApply(a: Apply): Expr = { - val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames) match { + val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match { case null => a case a => a } @@ -121,7 +121,7 @@ class StaticOptimizer( } } - private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = { + private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr], tailstrict: Boolean): Expr = { if(f.staticSafe && args.forall(_.isInstanceOf[Val])) { val vargs = args.map(_.asInstanceOf[Val]) try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null } @@ -131,20 +131,20 @@ class StaticOptimizer( private def specializeApplyArity(a: Apply): Expr = { if(a.namedNames != null) a else a.args.length match { - case 0 => Apply0(a.pos, a.value) - case 1 => Apply1(a.pos, a.value, a.args(0)) - case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1)) - case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2)) + case 0 => Apply0(a.pos, a.value, a.tailstrict) + case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict) + case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict) + case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict) case _ => a } } - private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String]): Expr = lhs match { + private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String], tailstrict: Boolean): Expr = lhs match { case f: Val.Builtin => rebind(args, names, f.params) match { case null => null case newArgs => - tryStaticApply(pos, f, newArgs) match { + tryStaticApply(pos, f, newArgs, tailstrict) match { case null => val (f2, rargs) = f.specialize(newArgs) match { case null => (f, newArgs) @@ -166,12 +166,12 @@ class StaticOptimizer( case ScopedVal(Function(_, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) } case ScopedVal(Bind(_, _, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) } case _ => null } diff --git a/sjsonnet/src/sjsonnet/Std.scala b/sjsonnet/src/sjsonnet/Std.scala index 1bfe6846..caa3fb84 100644 --- a/sjsonnet/src/sjsonnet/Std.scala +++ b/sjsonnet/src/sjsonnet/Std.scala @@ -650,6 +650,7 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map. q.removeFirst().force match { case v: Val.Arr => v.asLazyArray.reverseIterator.foreach(q.push) case s: Val.Str => out.write(s.value) + case _ => Error.fail("Cannot call deepJoin on " + value.prettyName) } } Val.Str(pos, out.toString) diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 33400201..8fe2225c 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -465,7 +465,10 @@ object Val{ val simple = namedNames == null && params.names.length == argsL.length val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } //println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}") - if(simple) { + if (simple || ev.tailstrict) { + if (ev.tailstrict) { + argsL.foreach(_.force) + } val newScope = defSiteValScope.extendSimple(argsL) evalRhs(newScope, ev, funDefFileScope, outerPos) } else { @@ -530,6 +533,9 @@ object Val{ if(params.names.length != 1) apply(Array(argVal), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -539,6 +545,10 @@ object Val{ if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal1.force + argVal2.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -548,6 +558,11 @@ object Val{ if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal1.force + argVal2.force + argVal3.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -575,6 +590,10 @@ object Val{ if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos) else evalRhs(Array(argVal1, argVal2), ev, outerPos) + override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val = + if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos) + else evalRhs(Array(argVal1, argVal2, argVal3), ev, outerPos) + /** Specialize a call to this function in the optimizer. Must return either `null` to leave the * call-site as it is or a pair of a (possibly different) `Builtin` and the arguments to pass * to it (usually a subset of the supplied arguments). @@ -647,6 +666,8 @@ object Val{ * throughout the Jsonnet evaluation. */ abstract class EvalScope extends EvalErrorScope with Ordering[Val] { + def tailstrict: Boolean + def visitExpr(expr: Expr) (implicit scope: ValScope): Val diff --git a/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet b/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet new file mode 100644 index 00000000..beaaead0 --- /dev/null +++ b/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet @@ -0,0 +1,38 @@ +/* +Copyright 2015 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +local fibonacci(n) = + if n <= 1 then + 1 + else + fibonacci(n - 1) + fibonacci(n - 2); + +std.assertEqual(fibonacci(15), 987) && + +// Tail recursive call +local sum(x, v) = + if x <= 0 then + v + else + sum(x - 1, x + v) tailstrict; + +// Scala Native is struggling with large stacks. +local sz = 1000; +std.assertEqual(sum(sz, 0), sz * (sz + 1) / 2) && + +std.assertEqual(local x() = 3; x() tailstrict, 3) && + +true diff --git a/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala b/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala index 395ef37e..688133c2 100644 --- a/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala +++ b/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala @@ -53,7 +53,7 @@ object FileTests extends TestSuite{ test("oop_extra") - check() test("parsing_edge_cases") - check() test("precedence") - check() -// test("recursive_function") - check() + test("recursive_function") - check() test("recursive_import_ok") - check() test("recursive_object") - check() test("regex") - check() diff --git a/sjsonnet/test/src-native/sjsonnet/FileTests.scala b/sjsonnet/test/src-native/sjsonnet/FileTests.scala index 01330941..bd4ff5de 100644 --- a/sjsonnet/test/src-native/sjsonnet/FileTests.scala +++ b/sjsonnet/test/src-native/sjsonnet/FileTests.scala @@ -53,7 +53,7 @@ object FileTests extends TestSuite{ test("oop_extra") - check() test("parsing_edge_cases") - check() test("precedence") - check() -// test("recursive_function") - check() + test("recursive_function_native") - check() test("recursive_import_ok") - check() test("recursive_object") - check() test("sanity") - checkGolden() diff --git a/sjsonnet/test/src/sjsonnet/FormatTests.scala b/sjsonnet/test/src/sjsonnet/FormatTests.scala index 936d8bfe..1ba6357d 100644 --- a/sjsonnet/test/src/sjsonnet/FormatTests.scala +++ b/sjsonnet/test/src/sjsonnet/FormatTests.scala @@ -10,6 +10,7 @@ object FormatTests extends TestSuite{ val json = ujson.read(jsonStr) val formatted = Format.format(fmt, Materializer.reverse(null, json), dummyPos)( new EvalScope{ + def tailstrict: Boolean = false def extVars = _ => None def wd: Path = DummyPath() def visitExpr(expr: Expr)(implicit scope: ValScope): Val = ???