Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for tailstrict #257

Merged
merged 4 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ lazy val main = (project in file("sjsonnet"))
.settings(
Compile / scalacOptions ++= Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**"),
Test / fork := true,
Test / javaOptions += "-Xss100m",
Test / baseDirectory := (ThisBuild / baseDirectory).value,
libraryDependencies ++= Seq(
"com.lihaoyi" %% "fastparse" % "2.3.3",
Expand Down
4 changes: 2 additions & 2 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ object sjsonnet extends Module {
ivy"org.yaml:snakeyaml::1.33",
ivy"com.google.re2j:re2j:1.7",
)
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**")
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.*,sjsonnet.**")

object test extends ScalaTests with CrossTests {
def forkOptions = Seq("-Xss100m")
def forkArgs = Seq("-Xss100m")
def sources = T.sources(
this.millSourcePath / "src",
this.millSourcePath / "src-jvm",
Expand Down
85 changes: 66 additions & 19 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Expr.{Error => _, _}
import sjsonnet.Expr.Member.Visibility
import ujson.Value

import scala.annotation.tailrec
import scala.collection.mutable

/**
Expand All @@ -27,8 +28,9 @@ class Evaluator(resolver: CachedResolver,

def materialize(v: Val): Value = Materializer.apply(v)
val cachedImports = collection.mutable.HashMap.empty[Path, Val]
var tailstrict: Boolean = false

def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
e match {
case e: ValidId => visitValidId(e)
case e: BinaryOp => visitBinaryOp(e)
Expand Down Expand Up @@ -184,40 +186,84 @@ class Evaluator(resolver: CachedResolver,

private def visitApply(e: Apply)(implicit scope: ValScope) = {
val lhs = visitExpr(e.value)
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1

if (tailstrict) {
lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
tailstrict = false
res
} else {
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}

private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = {
stephenamar-db marked this conversation as resolved.
Show resolved Hide resolved
val lhs = visitExpr(e.value)
lhs.cast[Val.Func].apply0(e.pos)
if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply0(e.pos)
tailstrict = false
res
} else {
lhs.cast[Val.Func].apply0(e.pos)
}
}

private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
}
}

private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
}
}

private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
if (tailstrict) {
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
} else if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
}
}

private def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope) =
Expand Down Expand Up @@ -653,7 +699,8 @@ class Evaluator(resolver: CachedResolver,
newSelf
}

def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
@tailrec
private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
case (spec @ ForSpec(_, name, expr)) :: rest =>
visitComp(
rest,
Expand Down
10 changes: 5 additions & 5 deletions sjsonnet/src/sjsonnet/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ object Expr{
case class ImportStr(pos: Position, value: String) extends Expr
case class ImportBin(pos: Position, value: String) extends Expr
case class Error(pos: Position, value: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String]) extends Expr
case class Apply0(pos: Position, value: Expr) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) extends Expr
case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr, tailstrict: Boolean) extends Expr
case class ApplyBuiltin(pos: Position, func: Val.Builtin, argExprs: Array[Expr]) extends Expr {
stephenamar-db marked this conversation as resolved.
Show resolved Hide resolved
override def exprErrorString: String = s"std.${func.functionName}"
}
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/ExprTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,37 @@ abstract class ExprTransform {
if(x2 eq x) expr
else Select(pos, x2, name)

case Apply(pos, x, y, namedNames) =>
case Apply(pos, x, y, namedNames, tailstrict) =>
val x2 = transform(x)
val y2 = transformArr(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply(pos, x2, y2, namedNames)
else Apply(pos, x2, y2, namedNames, tailstrict)

case Apply0(pos, x) =>
case Apply0(pos, x, tailstrict) =>
val x2 = transform(x)
if((x2 eq x)) expr
else Apply0(pos, x2)
else Apply0(pos, x2, tailstrict)

case Apply1(pos, x, y) =>
case Apply1(pos, x, y, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply1(pos, x2, y2)
else Apply1(pos, x2, y2, tailstrict)

case Apply2(pos, x, y, z) =>
case Apply2(pos, x, y, z, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
if((x2 eq x) && (y2 eq y) && (z2 eq z)) expr
else Apply2(pos, x2, y2, z2)
else Apply2(pos, x2, y2, z2, tailstrict)

case Apply3(pos, x, y, z, a) =>
case Apply3(pos, x, y, z, a, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
val a2 = transform(a)
if((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr
else Apply3(pos, x2, y2, z2, a2)
else Apply3(pos, x2, y2, z2, a2, tailstrict)

case ApplyBuiltin(pos, func, x) =>
val x2 = transformArr(x)
Expand Down
4 changes: 2 additions & 2 deletions sjsonnet/src/sjsonnet/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ class Parser(val currentFile: Path,
case (Some(tree), Seq()) => Expr.Lookup(i, _: Expr, tree)
case (start, ins) => Expr.Slice(i, _: Expr, start, ins.lift(0).flatten, ins.lift(1).flatten)
}
case '(' => Pass ~ (args ~ ")").map { case (args, namedNames) =>
Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames)
case '(' => Pass ~ (args ~ ")" ~ "tailstrict".!.?).map {
case (args, namedNames, tailstrict) => Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames, tailstrict.nonEmpty)
}
case '{' => Pass ~ (objinside ~ "}").map(x => Expr.ObjExtend(i, _: Expr, x))
case _ => Fail
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class StaticOptimizer(
}

private def transformApply(a: Apply): Expr = {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames) match {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match {
case null => a
case a => a
}
Expand All @@ -121,7 +121,7 @@ class StaticOptimizer(
}
}

private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = {
private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr], tailstrict: Boolean): Expr = {
if(f.staticSafe && args.forall(_.isInstanceOf[Val])) {
val vargs = args.map(_.asInstanceOf[Val])
try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null }
Expand All @@ -131,20 +131,20 @@ class StaticOptimizer(
private def specializeApplyArity(a: Apply): Expr = {
if(a.namedNames != null) a
else a.args.length match {
case 0 => Apply0(a.pos, a.value)
case 1 => Apply1(a.pos, a.value, a.args(0))
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1))
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2))
case 0 => Apply0(a.pos, a.value, a.tailstrict)
case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict)
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict)
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict)
case _ => a
}
}

private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String]): Expr = lhs match {
private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String], tailstrict: Boolean): Expr = lhs match {
case f: Val.Builtin =>
rebind(args, names, f.params) match {
case null => null
case newArgs =>
tryStaticApply(pos, f, newArgs) match {
tryStaticApply(pos, f, newArgs, tailstrict) match {
case null =>
val (f2, rargs) = f.specialize(newArgs) match {
case null => (f, newArgs)
Expand All @@ -166,12 +166,12 @@ class StaticOptimizer(
case ScopedVal(Function(_, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case ScopedVal(Bind(_, _, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case _ => null
}
Expand Down
1 change: 1 addition & 0 deletions sjsonnet/src/sjsonnet/Std.scala
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ class Std(private val additionalNativeFunctions: Map[String, Val.Builtin] = Map.
q.removeFirst().force match {
case v: Val.Arr => v.asLazyArray.reverseIterator.foreach(q.push)
case s: Val.Str => out.write(s.value)
case _ => Error.fail("Cannot call deepJoin on " + value.prettyName)
}
}
Val.Str(pos, out.toString)
Expand Down
23 changes: 22 additions & 1 deletion sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ object Val{
val simple = namedNames == null && params.names.length == argsL.length
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
//println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}")
if(simple) {
if (simple || ev.tailstrict) {
if (ev.tailstrict) {
argsL.foreach(_.force)
}
val newScope = defSiteValScope.extendSimple(argsL)
evalRhs(newScope, ev, funDefFileScope, outerPos)
} else {
Expand Down Expand Up @@ -530,6 +533,9 @@ object Val{
if(params.names.length != 1) apply(Array(argVal), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand All @@ -539,6 +545,10 @@ object Val{
if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal1.force
argVal2.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand All @@ -548,6 +558,11 @@ object Val{
if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
argVal1.force
argVal2.force
argVal3.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
Expand Down Expand Up @@ -575,6 +590,10 @@ object Val{
if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos)
else evalRhs(Array(argVal1, argVal2), ev, outerPos)

override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos)
else evalRhs(Array(argVal1, argVal2, argVal3), ev, outerPos)

/** Specialize a call to this function in the optimizer. Must return either `null` to leave the
* call-site as it is or a pair of a (possibly different) `Builtin` and the arguments to pass
* to it (usually a subset of the supplied arguments).
Expand Down Expand Up @@ -647,6 +666,8 @@ object Val{
* throughout the Jsonnet evaluation.
*/
abstract class EvalScope extends EvalErrorScope with Ordering[Val] {
def tailstrict: Boolean

def visitExpr(expr: Expr)
(implicit scope: ValScope): Val

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
Copyright 2015 Google Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

local fibonacci(n) =
if n <= 1 then
1
else
fibonacci(n - 1) + fibonacci(n - 2);

std.assertEqual(fibonacci(15), 987) &&

// Tail recursive call
local sum(x, v) =
if x <= 0 then
v
else
sum(x - 1, x + v) tailstrict;

// Scala Native is struggling with large stacks.
local sz = 1000;
std.assertEqual(sum(sz, 0), sz * (sz + 1) / 2) &&

std.assertEqual(local x() = 3; x() tailstrict, 3) &&

true
Loading
Loading