Skip to content

Commit

Permalink
No flatten for most backends (#5528)
Browse files Browse the repository at this point in the history
### Description
Do not flatten matches for most backends. Such flattening can cause the
generated code size to be the square of the input code size. Removing
the flattening leads to a huge decrease in the size of the generated C#
code inside the Dafny codebase. It goes from ~14K lines to ~7K.

The compilation of nested matches is done as follows:
```dafny
datatype TX = X(x1arg: TY, x2arg: TY)
datatype TY = Y(yarg: int)
datatype TZ = Z(zarg: TW)
datatype TW = W(warg: int)

method M(a: TX) {
  match a
    case X(Y(b),Z(W(c)) => <body1>
    case r => <body2>
}
```

Is roughly compiled into 

```dafny
// Same datatypes

method M(a: TX) {
  var unmatched := true;
  if (unmatched && a is X) {
    var x1arg1 = ((X)a).1;
    if (x1 is Y) {
      var b = ((Y)x1arg1).1;
      var x2arg2 := ((X)a).2; 
      if (x2 is Z) {
        var zarg1 := ((Z)x2arg2).1;
        if (x4 is W) {
          var c := ((W)zarg1).1;
          unmatched := false;
          <body1>
        }
      } 
    }
  }
  if (unmatched) {
    var r := a;
    <body2>
  }
}
```

#### Caveats

##### Maintainability 
To reduce the required work, Java and Dafny back-ends still compile
using flattened matches

Ideally the transformation would be a Dafny-to-Dafny source
transformation, instead of a customization of SinglePassCompiler.
However, Dafny does not allow using statements in expression contexts,
and this is needed for the transformation. I think it would be good to
have an intermediate Dafny that does allow this, similar to what
@cpitclaudel 's Dafny-in-Dafny compiler allowed, and then to define the
rewrite that this PR introduces using a Dafny source translation.

##### Improvement
For C# we could generate much nicer code, since C# allows declaring new
variables inside expressions using `x is T xAsType` expressions. We
could get rid of the nested `if` statements and the `unmatched`
variable. However, I'll leave this for future work.

### How has this been tested?
- Performance change. No additional tests added.

<small>By submitting this pull request, I confirm that my contribution
is made under the terms of the [MIT
license](https://github.com/dafny-lang/dafny/blob/master/LICENSE.txt).</small>
  • Loading branch information
keyboardDrummer authored Jun 5, 2024
1 parent fe3b56c commit 5b048ff
Show file tree
Hide file tree
Showing 35 changed files with 7,176 additions and 14,203 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/msbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:

integration-tests:
needs: check-deep-tests
if: always() && (( github.event_name == 'pull_request' && (needs.check-deep-tests.result == 'success' || contains(github.event.pull_request.labels.*.name, 'run-deep-tests'))) || ( github.event_name == 'push' && ( github.ref_name == 'master' || vars.TEST_ON_FORK == 'true' )))
if: always() && (( github.event_name == 'pull_request' && (needs.check-deep-tests.result == 'success' || contains(github.event.pull_request.labels.*.name, 'run-deep-tests') || contains(github.event.pull_request.labels.*.name, 'run-integration-tests'))) || ( github.event_name == 'push' && ( github.ref_name == 'master' || vars.TEST_ON_FORK == 'true' )))
uses: ./.github/workflows/integration-tests-reusable.yml
with:
ref: ${{ github.ref }}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ namespace Microsoft.Dafny;
interface INestedMatch : INode {
Expression Source { get; }
string MatchTypeName { get; }
IReadOnlyList<NestedMatchCase> Cases { get; }
}

public class NestedMatchExpr : Expression, ICloneable<NestedMatchExpr>, ICanFormat, INestedMatch {
public Expression Source { get; }
public string MatchTypeName => "expression";
public readonly List<NestedMatchCaseExpr> Cases;
public List<NestedMatchCaseExpr> Cases { get; }

IReadOnlyList<NestedMatchCase> INestedMatch.Cases => Cases;

public readonly bool UsesOptionalBraces;
public Attributes Attributes;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ namespace Microsoft.Dafny;
public class NestedMatchStmt : Statement, ICloneable<NestedMatchStmt>, ICanFormat, INestedMatch, ICanResolve {
public Expression Source { get; }
public string MatchTypeName => "statement";
public readonly List<NestedMatchCaseStmt> Cases;
public List<NestedMatchCaseStmt> Cases { get; }

IReadOnlyList<NestedMatchCase> INestedMatch.Cases => Cases;

public readonly bool UsesOptionalBraces;

[FilledInDuringResolution] public Statement Flattened { get; set; }
Expand Down
4 changes: 3 additions & 1 deletion Source/DafnyCore/AST/Substituter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ ExtendedPattern SubstituteForPattern(ExtendedPattern pattern) {
case IdPattern idPattern:
if (idPattern.BoundVar == null) {
return new IdPattern(idPattern.Tok, idPattern.Id, idPattern.Type,
idPattern.Arguments?.Select(SubstituteForPattern).ToList(), idPattern.IsGhost);
idPattern.Arguments?.Select(SubstituteForPattern).ToList(), idPattern.IsGhost) {
Ctor = idPattern.Ctor
};
}

discoveredBvs.Add((BoundVar)idPattern.BoundVar);
Expand Down
23 changes: 12 additions & 11 deletions Source/DafnyCore/Backends/CSharp/CsharpCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2958,7 +2958,8 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
return result;
}

protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex,
DatatypeCtor ctor, Func<List<Type>> getTypeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
Expand Down Expand Up @@ -3041,7 +3042,7 @@ static bool IsDirectlyComparable(Type t) {
}

protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
Expression e0, Expression e1, IToken tok, Type resultType,
Type e0Type, Type e1Type, IToken tok, Type resultType,
out string opString,
out string preOpString,
out string postOpString,
Expand All @@ -3065,7 +3066,7 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,

switch (op) {
case BinaryExpr.ResolvedOpcode.EqCommon: {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0.Type);
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0Type);
if (eqType.IsRefType) {
// Dafny's type rules are slightly different C#, so we may need a cast here.
// For example, Dafny allows x==y if x:array<T> and y:array<int> and T is some
Expand All @@ -3079,7 +3080,7 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
break;
}
case BinaryExpr.ResolvedOpcode.NeqCommon: {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0.Type);
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0Type);
if (eqType.IsRefType) {
// Dafny's type rules are slightly different C#, so we may need a cast here.
// For example, Dafny allows x==y if x:array<T> and y:array<int> and T is some
Expand Down Expand Up @@ -3170,14 +3171,14 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,

case BinaryExpr.ResolvedOpcode.ProperSubset:
case BinaryExpr.ResolvedOpcode.ProperMultiSubset:
staticCallString = TypeHelperName(e0.Type, errorWr, tok, e1.Type) + ".IsProperSubsetOf"; break;
staticCallString = TypeHelperName(e0Type, errorWr, tok, e1Type) + ".IsProperSubsetOf"; break;
case BinaryExpr.ResolvedOpcode.Subset:
case BinaryExpr.ResolvedOpcode.MultiSubset:
staticCallString = TypeHelperName(e0.Type, errorWr, tok, e1.Type) + ".IsSubsetOf"; break;
staticCallString = TypeHelperName(e0Type, errorWr, tok, e1Type) + ".IsSubsetOf"; break;

case BinaryExpr.ResolvedOpcode.Disjoint:
case BinaryExpr.ResolvedOpcode.MultiSetDisjoint:
staticCallString = TypeHelperName(e0.Type, errorWr, tok, e1.Type) + ".IsDisjointFrom"; break;
staticCallString = TypeHelperName(e0Type, errorWr, tok, e1Type) + ".IsDisjointFrom"; break;
case BinaryExpr.ResolvedOpcode.InSet:
case BinaryExpr.ResolvedOpcode.InMultiSet:
case BinaryExpr.ResolvedOpcode.InMap:
Expand All @@ -3200,14 +3201,14 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
staticCallString = TypeHelperName(resultType, errorWr, tok) + ".Subtract"; break;

case BinaryExpr.ResolvedOpcode.ProperPrefix:
staticCallString = TypeHelperName(e0.Type, errorWr, e0.tok) + ".IsProperPrefixOf"; break;
staticCallString = TypeHelperName(e0Type, errorWr, e0Type.tok) + ".IsProperPrefixOf"; break;
case BinaryExpr.ResolvedOpcode.Prefix:
staticCallString = TypeHelperName(e0.Type, errorWr, e0.tok) + ".IsPrefixOf"; break;
staticCallString = TypeHelperName(e0Type, errorWr, e0Type.tok) + ".IsPrefixOf"; break;
case BinaryExpr.ResolvedOpcode.Concat:
staticCallString = TypeHelperName(e0.Type, errorWr, e0.tok) + ".Concat"; break;
staticCallString = TypeHelperName(e0Type, errorWr, e0Type.tok) + ".Concat"; break;

default:
base.CompileBinOp(op, e0, e1, tok, resultType,
base.CompileBinOp(op, e0Type, e1Type, tok, resultType,
out opString, out preOpString, out postOpString, out callString, out staticCallString, out reverseArguments, out truncateResult, out convertE1_to_int, out coerceE1,
errorWr);
break;
Expand Down
15 changes: 8 additions & 7 deletions Source/DafnyCore/Backends/Cplusplus/CppCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1956,7 +1956,8 @@ protected override void EmitConstructorCheck(string source, DatatypeCtor ctor, C
wr.Write("is_{1}({0})", source, DatatypeSubStructName(ctor));
}

protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex,
DatatypeCtor ctor, Func<List<Type>> getTypeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (ctor.EnclosingDatatype is TupleTypeDecl) {
wr.Write("(");
source(wr);
Expand Down Expand Up @@ -2042,7 +2043,7 @@ bool IsDirectlyComparable(Type t) {
}

protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
Expression e0, Expression e1, IToken tok, Type resultType,
Type e0Type, Type e1Type, IToken tok, Type resultType,
out string opString,
out string preOpString,
out string postOpString,
Expand Down Expand Up @@ -2096,9 +2097,9 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
break;

case BinaryExpr.ResolvedOpcode.EqCommon: {
if (IsDirectlyComparable(e0.Type)) {
if (IsDirectlyComparable(e0Type)) {
opString = "==";
} else if (e0.Type.IsRefType) {
} else if (e0Type.IsRefType) {
opString = "==";
} else {
//staticCallString = "==";
Expand All @@ -2107,9 +2108,9 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
break;
}
case BinaryExpr.ResolvedOpcode.NeqCommon: {
if (IsDirectlyComparable(e0.Type)) {
if (IsDirectlyComparable(e0Type)) {
opString = "!=";
} else if (e0.Type.IsRefType) {
} else if (e0Type.IsRefType) {
opString = "!=";
} else {
opString = "!=";
Expand Down Expand Up @@ -2147,7 +2148,7 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
case BinaryExpr.ResolvedOpcode.RightShift:
if (AsNativeType(resultType) != null) {
opString = ">>";
if (AsNativeType(e1.Type) == null) {
if (AsNativeType(e1Type) == null) {
postOpString = ".Uint64()";
}
} else {
Expand Down
28 changes: 21 additions & 7 deletions Source/DafnyCore/Backends/Dafny/DafnyCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1923,8 +1923,9 @@ protected override ConcreteSyntaxTree EmitBetaRedex(List<string> boundVars, List
}
}

protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor,
List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex,
DatatypeCtor ctor,
Func<List<Type>> getTypeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (GetExprBuilder(wr, out var builder)) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Expand Down Expand Up @@ -2123,7 +2124,7 @@ private static DAST.Expression BinaryOp(_IBinOp op, _IExpression left, _IExpress
}

protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
Expression e0, Expression e1, IToken tok, Type resultType,
Type e0Type, Type e1Type, IToken tok, Type resultType,
out string opString,
out string preOpString,
out string postOpString,
Expand Down Expand Up @@ -2206,8 +2207,8 @@ object C(System.Func<DAST.Expression, DAST.Expression, DAST.Expression> callback

var newBuilder = op switch {
BinaryExpr.ResolvedOpcode.EqCommon => B((BinOp)BinOp.create_Eq(
e0.Type.IsRefType,
!e0.Type.IsNonNullRefType
e0Type.IsRefType,
!e0Type.IsNonNullRefType
)),
BinaryExpr.ResolvedOpcode.SetEq => B((BinOp)BinOp.create_Eq(false, false)),
BinaryExpr.ResolvedOpcode.MapEq => B((BinOp)BinOp.create_Eq(false, false)),
Expand All @@ -2216,8 +2217,8 @@ object C(System.Func<DAST.Expression, DAST.Expression, DAST.Expression> callback
BinaryExpr.ResolvedOpcode.NeqCommon => C((left, right) =>
Not(BinaryOp(
BinOp.create_Eq(
e0.Type.IsRefType,
!e0.Type.IsNonNullRefType
e0Type.IsRefType,
!e0Type.IsNonNullRefType
), left, right))),
BinaryExpr.ResolvedOpcode.SetNeq => C((left, right) =>
Not(BinaryOp(BinOp.create_Eq(false, false), left, right))),
Expand Down Expand Up @@ -2668,5 +2669,18 @@ protected override void EmitHaltRecoveryStmt(Statement body, string haltMessageV
AddUnsupported("<i>EmitHaltRecoveryStmt</i>");
}

protected override void EmitNestedMatchExpr(NestedMatchExpr match, bool inLetExprBody, ConcreteSyntaxTree output,
ConcreteSyntaxTree wStmts) {
EmitExpr(match.Flattened, inLetExprBody, output, wStmts);
}

protected override void TrOptNestedMatchExpr(NestedMatchExpr match, Type resultType, ConcreteSyntaxTree wr, ConcreteSyntaxTree wStmts,
bool inLetExprBody, IVariable accumulatorVar) {
TrExprOpt(match.Flattened, resultType, wr, wStmts, inLetExprBody, accumulatorVar);
}

protected override void EmitNestedMatchStmt(NestedMatchStmt match, ConcreteSyntaxTree writer) {
TrStmt(match.Flattened, writer);
}
}
}
33 changes: 17 additions & 16 deletions Source/DafnyCore/Backends/GoLang/GoCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3272,7 +3272,8 @@ protected override void EmitConstructorCheck(string source, DatatypeCtor ctor, C
wr.Write("{0}.{1}()", source, FormatDatatypeConstructorCheckName(ctor.GetCompileName(Options)));
}

protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex, DatatypeCtor ctor, List<Type> typeArgs, Type bvType, ConcreteSyntaxTree wr) {
protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal dtor, int formalNonGhostIndex,
DatatypeCtor ctor, Func<List<Type>> getTypeArgs, Type bvType, ConcreteSyntaxTree wr) {
if (DatatypeWrapperEraser.IsErasableDatatypeWrapper(Options, ctor.EnclosingDatatype, out var coreDtor)) {
Contract.Assert(coreDtor.CorrespondingFormals.Count == 1);
Contract.Assert(dtor == coreDtor.CorrespondingFormals[0]); // any other destructor is a ghost
Expand All @@ -3281,7 +3282,7 @@ protected override void EmitDestructor(Action<ConcreteSyntaxTree> source, Formal
Contract.Assert(tupleTypeDecl.NonGhostDims != 1); // such a tuple is an erasable-wrapper type, handled above
wr.Write("(*(");
source(wr);
wr.Write(").IndexInt({0})).({1})", formalNonGhostIndex, TypeName(typeArgs[formalNonGhostIndex], wr, Token.NoToken));
wr.Write(").IndexInt({0})).({1})", formalNonGhostIndex, TypeName(getTypeArgs()[formalNonGhostIndex], wr, Token.NoToken));
} else {
var dtorName = DatatypeFieldName(dtor, formalNonGhostIndex);
wr = EmitCoercionIfNecessary(from: dtor.Type, to: bvType, tok: dtor.tok, wr: wr);
Expand Down Expand Up @@ -3375,7 +3376,7 @@ private bool IsComparedByEquals(Type t) {
}

protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
Expression e0, Expression e1, IToken tok, Type resultType,
Type e0Type, Type e1Type, IToken tok, Type resultType,
out string opString,
out string preOpString,
out string postOpString,
Expand Down Expand Up @@ -3421,8 +3422,8 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
break;

case BinaryExpr.ResolvedOpcode.EqCommon: {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0.Type);
if (!EqualsUpToParameters(eqType, DatatypeWrapperEraser.SimplifyType(Options, e1.Type))) {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0Type);
if (!EqualsUpToParameters(eqType, DatatypeWrapperEraser.SimplifyType(Options, e1Type))) {
staticCallString = $"{HelperModulePrefix}AreEqual";
} else if (IsOrderedByCmp(eqType)) {
callString = "Cmp";
Expand All @@ -3437,8 +3438,8 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
break;
}
case BinaryExpr.ResolvedOpcode.NeqCommon: {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0.Type);
if (!EqualsUpToParameters(eqType, DatatypeWrapperEraser.SimplifyType(Options, e1.Type))) {
var eqType = DatatypeWrapperEraser.SimplifyType(Options, e0Type);
if (!EqualsUpToParameters(eqType, DatatypeWrapperEraser.SimplifyType(Options, e1Type))) {
preOpString = "!";
staticCallString = $"{HelperModulePrefix}AreEqual";
} else if (IsDirectlyComparable(eqType)) {
Expand All @@ -3458,31 +3459,31 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
}

case BinaryExpr.ResolvedOpcode.Lt:
if (IsOrderedByCmp(e0.Type)) {
if (IsOrderedByCmp(e0Type)) {
callString = "Cmp";
postOpString = " < 0";
} else {
opString = "<";
}
break;
case BinaryExpr.ResolvedOpcode.Le:
if (IsOrderedByCmp(e0.Type)) {
if (IsOrderedByCmp(e0Type)) {
callString = "Cmp";
postOpString = " <= 0";
} else {
opString = "<=";
}
break;
case BinaryExpr.ResolvedOpcode.Ge:
if (IsOrderedByCmp(e0.Type)) {
if (IsOrderedByCmp(e0Type)) {
callString = "Cmp";
postOpString = " >= 0";
} else {
opString = ">=";
}
break;
case BinaryExpr.ResolvedOpcode.Gt:
if (IsOrderedByCmp(e0.Type)) {
if (IsOrderedByCmp(e0Type)) {
callString = "Cmp";
postOpString = " > 0";
} else {
Expand All @@ -3495,11 +3496,11 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
}
if (AsNativeType(resultType) != null) {
opString = "<<";
if (AsNativeType(e1.Type) == null) {
if (AsNativeType(e1Type) == null) {
postOpString = ".Uint64()";
}
} else {
if (AsNativeType(e1.Type) != null) {
if (AsNativeType(e1Type) != null) {
callString = "Lsh(_dafny.IntOfUint64(uint64";
postOpString = "))";
} else {
Expand All @@ -3510,11 +3511,11 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
case BinaryExpr.ResolvedOpcode.RightShift:
if (AsNativeType(resultType) != null) {
opString = ">>";
if (AsNativeType(e1.Type) == null) {
if (AsNativeType(e1Type) == null) {
postOpString = ".Uint64()";
}
} else {
if (AsNativeType(e1.Type) != null) {
if (AsNativeType(e1Type) != null) {
callString = "Rsh(_dafny.IntOfUint64(uint64";
postOpString = "))";
} else {
Expand Down Expand Up @@ -3634,7 +3635,7 @@ protected override void CompileBinOp(BinaryExpr.ResolvedOpcode op,
staticCallString = $"{DafnySequenceCompanion}.Contains"; reverseArguments = true; break;

default:
base.CompileBinOp(op, e0, e1, tok, resultType,
base.CompileBinOp(op, e0Type, e1Type, tok, resultType,
out opString, out preOpString, out postOpString, out callString, out staticCallString, out reverseArguments, out truncateResult, out convertE1_to_int, out coerceE1,
errorWr);
break;
Expand Down
Loading

0 comments on commit 5b048ff

Please sign in to comment.