Skip to content

Commit

Permalink
tailstrict
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenamar-db committed Jan 8, 2025
1 parent 9143c58 commit 4eb8b30
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 73 deletions.
2 changes: 1 addition & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ object sjsonnet extends Module {
def scalacOptions = Seq("-opt:l:inline", "-opt-inline-from:sjsonnet.**")

object test extends ScalaTests with CrossTests {
def forkOptions = Seq("-Xss100m")
def forkArgs = Seq("-Xss100m")
def sources = T.sources(
this.millSourcePath / "src",
this.millSourcePath / "src-jvm",
Expand Down
79 changes: 61 additions & 18 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ class Evaluator(resolver: CachedResolver,
def materialize(v: Val): Value = Materializer.apply(v)
val cachedImports = collection.mutable.HashMap.empty[Path, Val]

def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
var isInTailstrictMode = false

override def tailstrict: Boolean = isInTailstrictMode


override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
e match {
case e: ValidId => visitValidId(e)
case e: BinaryOp => visitBinaryOp(e)
Expand Down Expand Up @@ -184,14 +189,24 @@ class Evaluator(resolver: CachedResolver,

private def visitApply(e: Apply)(implicit scope: ValScope) = {
val lhs = visitExpr(e.value)
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1

if (isInTailstrictMode) {
lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
isInTailstrictMode = false
res
} else {
val args = e.args
val argsL = new Array[Lazy](args.length)
var idx = 0
while (idx < args.length) {
argsL(idx) = visitAsLazy(args(idx))
idx += 1
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}
lhs.cast[Val.Func].apply(argsL, e.namedNames, e.pos)
}

private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = {
Expand All @@ -201,23 +216,50 @@ class Evaluator(resolver: CachedResolver,

private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
if (isInTailstrictMode) {
lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
isInTailstrictMode = false
res
} else {
val l1 = visitAsLazy(e.a1)
lhs.cast[Val.Func].apply1(l1, e.pos)
}
}

private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
if (isInTailstrictMode) {
lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
isInTailstrictMode = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
lhs.cast[Val.Func].apply2(l1, l2, e.pos)
}
}

private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
if (isInTailstrictMode) {
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
isInTailstrictMode = false
res
} else {
val l1 = visitAsLazy(e.a1)
val l2 = visitAsLazy(e.a2)
val l3 = visitAsLazy(e.a3)
lhs.cast[Val.Func].apply3(l1, l2, l3, e.pos)
}
}

private def visitApplyBuiltin1(e: ApplyBuiltin1)(implicit scope: ValScope) =
Expand Down Expand Up @@ -642,7 +684,8 @@ class Evaluator(resolver: CachedResolver,
newSelf
}

def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
@inline
private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
case (spec @ ForSpec(_, name, expr)) :: rest =>
visitComp(
rest,
Expand Down
10 changes: 5 additions & 5 deletions sjsonnet/src/sjsonnet/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,11 @@ object Expr{
case class ImportStr(pos: Position, value: String) extends Expr
case class ImportBin(pos: Position, value: String) extends Expr
case class Error(pos: Position, value: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String]) extends Expr
case class Apply0(pos: Position, value: Expr) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr) extends Expr
case class Apply(pos: Position, value: Expr, args: Array[Expr], namedNames: Array[String], tailstrict: Boolean) extends Expr
case class Apply0(pos: Position, value: Expr, tailstrict: Boolean) extends Expr
case class Apply1(pos: Position, value: Expr, a1: Expr, tailstrict: Boolean) extends Expr
case class Apply2(pos: Position, value: Expr, a1: Expr, a2: Expr, tailstrict: Boolean) extends Expr
case class Apply3(pos: Position, value: Expr, a1: Expr, a2: Expr, a3: Expr, tailstrict: Boolean) extends Expr
case class ApplyBuiltin(pos: Position, func: Val.Builtin, argExprs: Array[Expr]) extends Expr {
override def exprErrorString: String = s"std.${func.functionName}"
}
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/ExprTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,37 @@ abstract class ExprTransform {
if(x2 eq x) expr
else Select(pos, x2, name)

case Apply(pos, x, y, namedNames) =>
case Apply(pos, x, y, namedNames, tailstrict) =>
val x2 = transform(x)
val y2 = transformArr(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply(pos, x2, y2, namedNames)
else Apply(pos, x2, y2, namedNames, tailstrict)

case Apply0(pos, x) =>
case Apply0(pos, x, tailstrict) =>
val x2 = transform(x)
if((x2 eq x)) expr
else Apply0(pos, x2)
else Apply0(pos, x2, tailstrict)

case Apply1(pos, x, y) =>
case Apply1(pos, x, y, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
if((x2 eq x) && (y2 eq y)) expr
else Apply1(pos, x2, y2)
else Apply1(pos, x2, y2, tailstrict)

case Apply2(pos, x, y, z) =>
case Apply2(pos, x, y, z, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
if((x2 eq x) && (y2 eq y) && (z2 eq z)) expr
else Apply2(pos, x2, y2, z2)
else Apply2(pos, x2, y2, z2, tailstrict)

case Apply3(pos, x, y, z, a) =>
case Apply3(pos, x, y, z, a, tailstrict) =>
val x2 = transform(x)
val y2 = transform(y)
val z2 = transform(z)
val a2 = transform(a)
if((x2 eq x) && (y2 eq y) && (z2 eq z) && (a2 eq a)) expr
else Apply3(pos, x2, y2, z2, a2)
else Apply3(pos, x2, y2, z2, a2, tailstrict)

case ApplyBuiltin(pos, func, x) =>
val x2 = transformArr(x)
Expand Down
2 changes: 1 addition & 1 deletion sjsonnet/src/sjsonnet/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class Interpreter(extVars: Map[String, String],
override def evalDefault(expr: Expr, vs: ValScope, es: EvalScope) = {
evaluator.visitExpr(expr)(if (tlaExpressions.exists(_ eq expr)) ValScope.empty else vs)
}
}.apply0(f.pos)(evaluator)
}.apply0(f.pos)(evaluator, f.defSiteValScope)
case x => x
}
} yield res
Expand Down
4 changes: 2 additions & 2 deletions sjsonnet/src/sjsonnet/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ class Parser(val currentFile: Path,
case (Some(tree), Seq()) => Expr.Lookup(i, _: Expr, tree)
case (start, ins) => Expr.Slice(i, _: Expr, start, ins.lift(0).flatten, ins.lift(1).flatten)
}
case '(' => Pass ~ (args ~ ")").map { case (args, namedNames) =>
Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames)
case '(' => Pass ~ (args ~ ")" ~ "tailstrict".?.!).map {
case (args, namedNames, tailstrict) => Expr.Apply(i, _: Expr, args, if(namedNames.length == 0) null else namedNames, tailstrict.nonEmpty)
}
case '{' => Pass ~ (objinside ~ "}").map(x => Expr.ObjExtend(i, _: Expr, x))
case _ => Fail
Expand Down
20 changes: 10 additions & 10 deletions sjsonnet/src/sjsonnet/StaticOptimizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class StaticOptimizer(
}

private def transformApply(a: Apply): Expr = {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames) match {
val rebound = rebindApply(a.pos, a.value, a.args, a.namedNames, a.tailstrict) match {
case null => a
case a => a
}
Expand All @@ -121,7 +121,7 @@ class StaticOptimizer(
}
}

private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr]): Expr = {
private def tryStaticApply(pos: Position, f: Val.Builtin, args: Array[Expr], tailstrict: Boolean): Expr = {
if(f.staticSafe && args.forall(_.isInstanceOf[Val])) {
val vargs = args.map(_.asInstanceOf[Val])
try f.apply(vargs, null, pos)(ev).asInstanceOf[Expr] catch { case _: Exception => return null }
Expand All @@ -131,20 +131,20 @@ class StaticOptimizer(
private def specializeApplyArity(a: Apply): Expr = {
if(a.namedNames != null) a
else a.args.length match {
case 0 => Apply0(a.pos, a.value)
case 1 => Apply1(a.pos, a.value, a.args(0))
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1))
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2))
case 0 => Apply0(a.pos, a.value, a.tailstrict)
case 1 => Apply1(a.pos, a.value, a.args(0), a.tailstrict)
case 2 => Apply2(a.pos, a.value, a.args(0), a.args(1), a.tailstrict)
case 3 => Apply3(a.pos, a.value, a.args(0), a.args(1), a.args(2), a.tailstrict)
case _ => a
}
}

private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String]): Expr = lhs match {
private def rebindApply(pos: Position, lhs: Expr, args: Array[Expr], names: Array[String], tailstrict: Boolean): Expr = lhs match {
case f: Val.Builtin =>
rebind(args, names, f.params) match {
case null => null
case newArgs =>
tryStaticApply(pos, f, newArgs) match {
tryStaticApply(pos, f, newArgs, tailstrict) match {
case null =>
val (f2, rargs) = f.specialize(newArgs) match {
case null => (f, newArgs)
Expand All @@ -166,12 +166,12 @@ class StaticOptimizer(
case ScopedVal(Function(_, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case ScopedVal(Bind(_, _, params, _), _, _) =>
rebind(args, names, params) match {
case null => null
case newArgs => Apply(pos, lhs, newArgs, null)
case newArgs => Apply(pos, lhs, newArgs, null, tailstrict)
}
case _ => null
}
Expand Down
Loading

0 comments on commit 4eb8b30

Please sign in to comment.