Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenamar-db committed Jan 14, 2025
1 parent 01a33dd commit fa3f21c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 57 deletions.
42 changes: 23 additions & 19 deletions sjsonnet/src/sjsonnet/Evaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Expr.{Error => _, _}
import sjsonnet.Expr.Member.Visibility
import ujson.Value

import scala.annotation.tailrec
import scala.collection.mutable

/**
Expand All @@ -27,11 +28,7 @@ class Evaluator(resolver: CachedResolver,

def materialize(v: Val): Value = Materializer.apply(v)
val cachedImports = collection.mutable.HashMap.empty[Path, Val]

var isInTailstrictMode = false

override def tailstrict: Boolean = isInTailstrictMode

var tailstrict: Boolean = false

override def visitExpr(e: Expr)(implicit scope: ValScope): Val = try {
e match {
Expand Down Expand Up @@ -190,12 +187,12 @@ class Evaluator(resolver: CachedResolver,
private def visitApply(e: Apply)(implicit scope: ValScope) = {
val lhs = visitExpr(e.value)

if (isInTailstrictMode) {
if (tailstrict) {
lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
tailstrict = true
val res = lhs.cast[Val.Func].apply(e.args.map(visitExpr(_)), e.namedNames, e.pos)
isInTailstrictMode = false
tailstrict = false
res
} else {
val args = e.args
Expand All @@ -211,17 +208,24 @@ class Evaluator(resolver: CachedResolver,

private def visitApply0(e: Apply0)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
lhs.cast[Val.Func].apply0(e.pos)
if (e.tailstrict) {
tailstrict = true
val res = lhs.cast[Val.Func].apply0(e.pos)
tailstrict = false
res
} else {
lhs.cast[Val.Func].apply0(e.pos)
}
}

private def visitApply1(e: Apply1)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
if (isInTailstrictMode) {
if (tailstrict) {
lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
tailstrict = true
val res = lhs.cast[Val.Func].apply1(visitExpr(e.a1), e.pos)
isInTailstrictMode = false
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
Expand All @@ -231,12 +235,12 @@ class Evaluator(resolver: CachedResolver,

private def visitApply2(e: Apply2)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
if (isInTailstrictMode) {
if (tailstrict) {
lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
tailstrict = true
val res = lhs.cast[Val.Func].apply2(visitExpr(e.a1), visitExpr(e.a2), e.pos)
isInTailstrictMode = false
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
Expand All @@ -247,12 +251,12 @@ class Evaluator(resolver: CachedResolver,

private def visitApply3(e: Apply3)(implicit scope: ValScope): Val = {
val lhs = visitExpr(e.value)
if (isInTailstrictMode) {
if (tailstrict) {
lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
} else if (e.tailstrict) {
isInTailstrictMode = true
tailstrict = true
val res = lhs.cast[Val.Func].apply3(visitExpr(e.a1), visitExpr(e.a2), visitExpr(e.a3), e.pos)
isInTailstrictMode = false
tailstrict = false
res
} else {
val l1 = visitAsLazy(e.a1)
Expand Down Expand Up @@ -695,7 +699,7 @@ class Evaluator(resolver: CachedResolver,
newSelf
}

@inline
@tailrec
private final def visitComp(f: List[CompSpec], scopes: Array[ValScope]): Array[ValScope] = f match{
case (spec @ ForSpec(_, name, expr)) :: rest =>
visitComp(
Expand Down
2 changes: 1 addition & 1 deletion sjsonnet/src/sjsonnet/Interpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class Interpreter(extVars: Map[String, String],
override def evalDefault(expr: Expr, vs: ValScope, es: EvalScope) = {
evaluator.visitExpr(expr)(if (tlaExpressions.exists(_ eq expr)) ValScope.empty else vs)
}
}.apply0(f.pos)(evaluator, f.defSiteValScope)
}.apply0(f.pos)(evaluator)
case x => x
}
} yield res
Expand Down
68 changes: 31 additions & 37 deletions sjsonnet/src/sjsonnet/Val.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,14 @@ object Val{

override def asFunc: Func = this

def apply(argsL: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope = defSiteValScope): Val = {
def apply(argsL: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val = {
val simple = namedNames == null && params.names.length == argsL.length
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
//println(s"apply: argsL: ${argsL.length}, namedNames: $namedNames, paramNames: ${params.names.mkString(",")}")
if (ev.tailstrict) {
System.arraycopy(argsL, 0, vs.bindings, defSiteValScope.length, argsL.length)
evalRhs(vs, ev, funDefFileScope, outerPos)
} else if(simple) {
if (simple || ev.tailstrict) {
if (ev.tailstrict) {
argsL.foreach(_.force)
}
val newScope = defSiteValScope.extendSimple(argsL)
evalRhs(newScope, ev, funDefFileScope, outerPos)
} else {
Expand Down Expand Up @@ -521,56 +521,50 @@ object Val{
}
}

def apply0(outerPos: Position)(implicit ev: EvalScope, vs: ValScope = defSiteValScope): Val = {
def apply0(outerPos: Position)(implicit ev: EvalScope): Val = {
if(params.names.length != 0) apply(Evaluator.emptyLazyArray, null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
evalRhs(defSiteValScope, ev, funDefFileScope, outerPos)
}
}

def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope = defSiteValScope): Val = {
def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope): Val = {
if(params.names.length != 1) apply(Array(argVal), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
vs.bindings(defSiteValScope.length) = argVal
evalRhs(vs, ev, funDefFileScope, outerPos)
} else {
val newScope: ValScope = defSiteValScope.extendSimple(argVal)
evalRhs(newScope, ev, funDefFileScope, outerPos)
argVal.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
}

def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope = defSiteValScope): Val = {
def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope): Val = {
if(params.names.length != 2) apply(Array(argVal1, argVal2), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
vs.bindings(defSiteValScope.length) = argVal1
vs.bindings(defSiteValScope.length+1) = argVal2
evalRhs(vs, ev, funDefFileScope, outerPos)
} else {
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2)
evalRhs(newScope, ev, funDefFileScope, outerPos)
argVal1.force
argVal2.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
}

def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope = defSiteValScope): Val = {
def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val = {
if(params.names.length != 3) apply(Array(argVal1, argVal2, argVal3), null, outerPos)
else {
val funDefFileScope: FileScope = pos match { case null => outerPos.fileScope case p => p.fileScope }
if (ev.tailstrict) {
vs.bindings(defSiteValScope.length) = argVal1
vs.bindings(defSiteValScope.length+1) = argVal2
vs.bindings(defSiteValScope.length+2) = argVal3
evalRhs(vs, ev, funDefFileScope, outerPos)
} else {
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3)
evalRhs(newScope, ev, funDefFileScope, outerPos)
argVal1.force
argVal2.force
argVal3.force
}
val newScope: ValScope = defSiteValScope.extendSimple(argVal1, argVal2, argVal3)
evalRhs(newScope, ev, funDefFileScope, outerPos)
}
}
}
Expand All @@ -588,16 +582,16 @@ object Val{

def evalRhs(args: Array[_ <: Lazy], ev: EvalScope, pos: Position): Val

override def apply(argsL: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply(argsL: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val =
evalRhs(argsL, ev, outerPos)

override def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
evalRhs(Array(argVal), ev, outerPos)

override def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
evalRhs(Array(argVal1, argVal2), ev, outerPos)

override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply3(argVal1: Lazy, argVal2: Lazy, argVal3: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
evalRhs(Array(argVal1, argVal2, argVal3), ev, outerPos)

/** Specialize a call to this function in the optimizer. Must return either `null` to leave the
Expand All @@ -617,11 +611,11 @@ object Val{

def evalRhs(arg1: Val, ev: EvalScope, pos: Position): Val

override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val =
if(namedNames == null && argVals.length == 1) evalRhs(argVals(0).force, ev, outerPos)
else super.apply(argVals, namedNames, outerPos)

override def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply1(argVal: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
if(params.names.length == 1) evalRhs(argVal.force, ev, outerPos)
else super.apply(Array(argVal), null, outerPos)
}
Expand All @@ -632,12 +626,12 @@ object Val{

def evalRhs(arg1: Val, arg2: Val, ev: EvalScope, pos: Position): Val

override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val =
if(namedNames == null && argVals.length == 2)
evalRhs(argVals(0).force, argVals(1).force, ev, outerPos)
else super.apply(argVals, namedNames, outerPos)

override def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply2(argVal1: Lazy, argVal2: Lazy, outerPos: Position)(implicit ev: EvalScope): Val =
if(params.names.length == 2) evalRhs(argVal1.force, argVal2.force, ev, outerPos)
else super.apply(Array(argVal1, argVal2), null, outerPos)
}
Expand All @@ -648,7 +642,7 @@ object Val{

def evalRhs(arg1: Val, arg2: Val, arg3: Val, ev: EvalScope, pos: Position): Val

override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val =
if(namedNames == null && argVals.length == 3)
evalRhs(argVals(0).force, argVals(1).force, argVals(2).force, ev, outerPos)
else super.apply(argVals, namedNames, outerPos)
Expand All @@ -660,7 +654,7 @@ object Val{

def evalRhs(arg1: Val, arg2: Val, arg3: Val, arg4: Val, ev: EvalScope, pos: Position): Val

override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope, vs: ValScope): Val =
override def apply(argVals: Array[_ <: Lazy], namedNames: Array[String], outerPos: Position)(implicit ev: EvalScope): Val =
if(namedNames == null && argVals.length == 4)
evalRhs(argVals(0).force, argVals(1).force, argVals(2).force, argVals(3).force, ev, outerPos)
else super.apply(argVals, namedNames, outerPos)
Expand Down

0 comments on commit fa3f21c

Please sign in to comment.