diff --git a/build.sbt b/build.sbt index 1a6a9913..2d4db46b 100644 --- a/build.sbt +++ b/build.sbt @@ -8,6 +8,7 @@ lazy val main = (project in file("sjsonnet")) .settings( Compile / scalacOptions ++= Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**"), Test / fork := true, + Test / javaOptions += "-Xss100m", Test / baseDirectory := (ThisBuild / baseDirectory).value, libraryDependencies ++= Seq( "com.lihaoyi" %% "fastparse" % "2.3.3", diff --git a/build.sc b/build.sc index 041b1876..88457977 100644 --- a/build.sc +++ b/build.sc @@ -111,10 +111,10 @@ object sjsonnet extends Module { ivy"org.yaml:snakeyaml::1.33", ivy"com.google.re2j:re2j:1.7", ) - def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**") + def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**") object test extends ScalaTests with CrossTests { - def forkOptions = Seq("-Xss100m") + def forkArgs = Seq("-Xss100m") def sources = T.sources( this.millSourcePath / "src", this.millSourcePath / "src-jvm", diff --git a/sjsonnet/src/sjsonnet/Evaluator.scala b/sjsonnet/src/sjsonnet/Evaluator.scala index 765705e8..5c3a7f90 100644 --- a/sjsonnet/src/sjsonnet/Evaluator.scala +++ b/sjsonnet/src/sjsonnet/Evaluator.scala @@ -4,6 +4,7 @@ import Expr.{Error => _, _} import sjsonnet.Expr.Member.Visibility import ujson.Value +import scala.annotation.tailrec import scala.collection.mutable /** @@ -27,8 +28,9 @@ class Evaluator(resolver: CachedResolver, def materialize(v: Val): Value = Materializer.apply(v) val cachedImports = collection.mutable.HashMap.empty[Path, Val] + var tailstrict: Boolean = false - def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { + override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try { e match { case e: ValidId => visitValidId(e) case e: BinaryOp => visitBinaryOp(e) @@ -184,40 +186,84 @@ class Evaluator(resolver: CachedResolver, private def visitApply(e: Apply)(implicit scope: ValScope) = { val lhs = visitExpr(e.value) - val args = e.args - val argsL = new Array[Lazy](args.length) - var idx = 0 - while (idx < args.length) { - argsL(idx) = visitAsLazy(args(idx)) - idx += 1 + + if (tailstrict) { + lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos) + tailstrict = false + res + } else { + val args = e.args + val argsL = new Array[Lazy](args.length) + var idx = 0 + while (idx < args.length) { + argsL(idx) = visitAsLazy(args(idx)) + idx += 1 + } + lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos) } - lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos) } private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - lhs.cast[Val.Func].apply0(e.pos) + if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply0(e.pos) + tailstrict = false + res + } else { + lhs.cast[Val.Func].apply0(e.pos) + } } private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - lhs.cast[Val.Func].apply1(l1, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + lhs.cast[Val.Func].apply1(l1, e.pos) + } } private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - val l2 = visitAsLazy(e.a2) - lhs.cast[Val.Func].apply2(l1, l2, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + val l2 = visitAsLazy(e.a2) + lhs.cast[Val.Func].apply2(l1, l2, e.pos) + } } private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = { val lhs = visitExpr(e.value) - val l1 = visitAsLazy(e.a1) - val l2 = visitAsLazy(e.a2) - val l3 = visitAsLazy(e.a3) - lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos) + if (tailstrict) { + lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + } else if (e.tailstrict) { + tailstrict = true + val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos) + tailstrict = false + res + } else { + val l1 = visitAsLazy(e.a1) + val l2 = visitAsLazy(e.a2) + val l3 = visitAsLazy(e.a3) + lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos) + } } private def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope) = @@ -653,7 +699,8 @@ class Evaluator(resolver: CachedResolver, newSelf } - def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{ + @tailrec + private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{ case (spec @ ForSpec(_, name, expr)) :: rest => visitComp( rest, diff --git a/sjsonnet/src/sjsonnet/Expr.scala b/sjsonnet/src/sjsonnet/Expr.scala index 14643887..e014492e 100644 --- a/sjsonnet/src/sjsonnet/Expr.scala +++ b/sjsonnet/src/sjsonnet/Expr.scala @@ -128,11 +128,11 @@ object Expr{ case class ImportStr(pos: Position, value: String) extends Expr case class ImportBin(pos: Position, value: String) extends Expr case class Error(pos: Position, value: Expr) extends Expr - case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String]) extends Expr - case class Apply0(pos: Position, value: Expr) extends Expr - case class Apply1(pos: Position, value: Expr, a1: Expr) extends Expr - case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr) extends Expr - case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr) extends Expr + case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) extends Expr + case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr + case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr + case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) extends Expr + case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr, tailstrict: Boolean) extends Expr case class ApplyBuiltin(pos: Position, func: Val.Builtin, argExprs: Array[Expr]) extends Expr { override def exprErrorString: String = s"std.${func.functionName}" } diff --git a/sjsonnet/src/sjsonnet/ExprTransform.scala b/sjsonnet/src/sjsonnet/ExprTransform.scala index 964833ee..d7d0657d 100644 --- a/sjsonnet/src/sjsonnet/ExprTransform.scala +++ b/sjsonnet/src/sjsonnet/ExprTransform.scala @@ -14,37 +14,37 @@ abstract class ExprTransform { if(x2 eq x) expr else Select(pos, x2, name) - case Apply(pos, x, y, namedNames) => + case Apply(pos, x, y, namedNames, tailstrict) => val x2 = transform(x) val y2 = transformArr(y) if((x2 eq x) && (y2 eq y)) expr - else Apply(pos, x2, y2, namedNames) + else Apply(pos, x2, y2, namedNames, tailstrict) - case Apply0(pos, x) => + case Apply0(pos, x, tailstrict) => val x2 = transform(x) if((x2 eq x)) expr - else Apply0(pos, x2) + else Apply0(pos, x2, tailstrict) - case Apply1(pos, x, y) => + case Apply1(pos, x, y, tailstrict) => val x2 = transform(x) val y2 = transform(y) if((x2 eq x) && (y2 eq y)) expr - else Apply1(pos, x2, y2) + else Apply1(pos, x2, y2, tailstrict) - case Apply2(pos, x, y, z) => + case Apply2(pos, x, y, z, tailstrict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) if((x2 eq x) && (y2 eq y) && (z2 eq z)) expr - else Apply2(pos, x2, y2, z2) + else Apply2(pos, x2, y2, z2, tailstrict) - case Apply3(pos, x, y, z, a) => + case Apply3(pos, x, y, z, a, tailstrict) => val x2 = transform(x) val y2 = transform(y) val z2 = transform(z) val a2 = transform(a) if((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr - else Apply3(pos, x2, y2, z2, a2) + else Apply3(pos, x2, y2, z2, a2, tailstrict) case ApplyBuiltin(pos, func, x) => val x2 = transformArr(x) diff --git a/sjsonnet/src/sjsonnet/Parser.scala b/sjsonnet/src/sjsonnet/Parser.scala index 1c06df6a..ffc01dfb 100644 --- a/sjsonnet/src/sjsonnet/Parser.scala +++ b/sjsonnet/src/sjsonnet/Parser.scala @@ -227,8 +227,8 @@ class Parser(val currentFile: Path, case (Some(tree), Seq()) => Expr.Lookup(i, _: Expr, tree) case (start, ins) => Expr.Slice(i, _: Expr, start, ins.lift(0).flatten, ins.lift(1).flatten) } - case '(' => Pass ~ (args ~ ")").map { case (args, namedNames) => - Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames) + case '(' => Pass ~ (args ~ ")" ~ "tailstrict".!.?).map { + case (args, namedNames, tailstrict) => Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames, tailstrict.nonEmpty) } case '{' => Pass ~ (objinside ~ "}").map(x => Expr.ObjExtend(i, _: Expr, x)) case _ => Fail diff --git a/sjsonnet/src/sjsonnet/StaticOptimizer.scala b/sjsonnet/src/sjsonnet/StaticOptimizer.scala index 6e1e2854..694ef7fb 100644 --- a/sjsonnet/src/sjsonnet/StaticOptimizer.scala +++ b/sjsonnet/src/sjsonnet/StaticOptimizer.scala @@ -111,7 +111,7 @@ class StaticOptimizer( } private def transformApply(a: Apply): Expr = { - val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames) match { + val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match { case null => a case a => a } @@ -121,7 +121,7 @@ class StaticOptimizer( } } - private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = { + private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr], tailstrict: Boolean): Expr = { if(f.staticSafe && args.forall(_.isInstanceOf[Val])) { val vargs = args.map(_.asInstanceOf[Val]) try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null } @@ -131,20 +131,20 @@ class StaticOptimizer( private def specializeApplyArity(a: Apply): Expr = { if(a.namedNames != null) a else a.args.length match { - case 0 => Apply0(a.pos, a.value) - case 1 => Apply1(a.pos, a.value, a.args(0)) - case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1)) - case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2)) + case 0 => Apply0(a.pos, a.value, a.tailstrict) + case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict) + case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict) + case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict) case _ => a } } - private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String]): Expr = lhs match { + private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String], tailstrict: Boolean): Expr = lhs match { case f: Val.Builtin => rebind(args, names, f.params) match { case null => null case newArgs => - tryStaticApply(pos, f, newArgs) match { + tryStaticApply(pos, f, newArgs, tailstrict) match { case null => val (f2, rargs) = f.specialize(newArgs) match { case null => (f, newArgs) @@ -166,12 +166,12 @@ class StaticOptimizer( case ScopedVal(Function(_, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) } case ScopedVal(Bind(_, _, params, _), _, _) => rebind(args, names, params) match { case null => null - case newArgs => Apply(pos, lhs, newArgs, null) + case newArgs => Apply(pos, lhs, newArgs, null, tailstrict) } case _ => null } diff --git a/sjsonnet/src/sjsonnet/Std.scala b/sjsonnet/src/sjsonnet/Std.scala index 1bfe6846..caa3fb84 100644 --- a/sjsonnet/src/sjsonnet/Std.scala +++ b/sjsonnet/src/sjsonnet/Std.scala @@ -650,6 +650,7 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map. q.removeFirst().force match { case v: Val.Arr => v.asLazyArray.reverseIterator.foreach(q.push) case s: Val.Str => out.write(s.value) + case _ => Error.fail("Cannot call deepJoin on " + value.prettyName) } } Val.Str(pos, out.toString) diff --git a/sjsonnet/src/sjsonnet/Val.scala b/sjsonnet/src/sjsonnet/Val.scala index 33400201..8fe2225c 100644 --- a/sjsonnet/src/sjsonnet/Val.scala +++ b/sjsonnet/src/sjsonnet/Val.scala @@ -465,7 +465,10 @@ object Val{ val simple = namedNames == null && params.names.length == argsL.length val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } //println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}") - if(simple) { + if (simple || ev.tailstrict) { + if (ev.tailstrict) { + argsL.foreach(_.force) + } val newScope = defSiteValScope.extendSimple(argsL) evalRhs(newScope, ev, funDefFileScope, outerPos) } else { @@ -530,6 +533,9 @@ object Val{ if(params.names.length != 1) apply(Array(argVal), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -539,6 +545,10 @@ object Val{ if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal1.force + argVal2.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -548,6 +558,11 @@ object Val{ if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos) else { val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope } + if (ev.tailstrict) { + argVal1.force + argVal2.force + argVal3.force + } val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3) evalRhs(newScope, ev, funDefFileScope, outerPos) } @@ -575,6 +590,10 @@ object Val{ if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos) else evalRhs(Array(argVal1, argVal2), ev, outerPos) + override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val = + if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos) + else evalRhs(Array(argVal1, argVal2, argVal3), ev, outerPos) + /** Specialize a call to this function in the optimizer. Must return either `null` to leave the * call-site as it is or a pair of a (possibly different) `Builtin` and the arguments to pass * to it (usually a subset of the supplied arguments). @@ -647,6 +666,8 @@ object Val{ * throughout the Jsonnet evaluation. */ abstract class EvalScope extends EvalErrorScope with Ordering[Val] { + def tailstrict: Boolean + def visitExpr(expr: Expr) (implicit scope: ValScope): Val diff --git a/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet b/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet new file mode 100644 index 00000000..beaaead0 --- /dev/null +++ b/sjsonnet/test/resources/test_suite/recursive_function_native.jsonnet @@ -0,0 +1,38 @@ +/* +Copyright 2015 Google Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +local fibonacci(n) = + if n <= 1 then + 1 + else + fibonacci(n - 1) + fibonacci(n - 2); + +std.assertEqual(fibonacci(15), 987) && + +// Tail recursive call +local sum(x, v) = + if x <= 0 then + v + else + sum(x - 1, x + v) tailstrict; + +// Scala Native is struggling with large stacks. +local sz = 1000; +std.assertEqual(sum(sz, 0), sz * (sz + 1) / 2) && + +std.assertEqual(local x() = 3; x() tailstrict, 3) && + +true diff --git a/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala b/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala index 395ef37e..688133c2 100644 --- a/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala +++ b/sjsonnet/test/src-jvm/sjsonnet/FileTests.scala @@ -53,7 +53,7 @@ object FileTests extends TestSuite{ test("oop_extra") - check() test("parsing_edge_cases") - check() test("precedence") - check() -// test("recursive_function") - check() + test("recursive_function") - check() test("recursive_import_ok") - check() test("recursive_object") - check() test("regex") - check() diff --git a/sjsonnet/test/src-native/sjsonnet/FileTests.scala b/sjsonnet/test/src-native/sjsonnet/FileTests.scala index 01330941..bd4ff5de 100644 --- a/sjsonnet/test/src-native/sjsonnet/FileTests.scala +++ b/sjsonnet/test/src-native/sjsonnet/FileTests.scala @@ -53,7 +53,7 @@ object FileTests extends TestSuite{ test("oop_extra") - check() test("parsing_edge_cases") - check() test("precedence") - check() -// test("recursive_function") - check() + test("recursive_function_native") - check() test("recursive_import_ok") - check() test("recursive_object") - check() test("sanity") - checkGolden() diff --git a/sjsonnet/test/src/sjsonnet/FormatTests.scala b/sjsonnet/test/src/sjsonnet/FormatTests.scala index 936d8bfe..1ba6357d 100644 --- a/sjsonnet/test/src/sjsonnet/FormatTests.scala +++ b/sjsonnet/test/src/sjsonnet/FormatTests.scala @@ -10,6 +10,7 @@ object FormatTests extends TestSuite{ val json = ujson.read(jsonStr) val formatted = Format.format(fmt, Materializer.reverse(null, json), dummyPos)( new EvalScope{ + def tailstrict: Boolean = false def extVars = _ => None def wd: Path = DummyPath() def visitExpr(expr: Expr)(implicit scope: ValScope): Val = ???