Skip to content

Commit

Permalink
Merge pull request #7204 from smores56/constrain-early-return-functions
Browse files Browse the repository at this point in the history
Constrain early returns in functions in addition to closures
  • Loading branch information
smores56 authored Nov 22, 2024
2 parents 6a3db1e + c5b2e16 commit 22423ca
Show file tree
Hide file tree
Showing 6 changed files with 436 additions and 64 deletions.
206 changes: 143 additions & 63 deletions crates/compiler/constrain/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,37 +168,38 @@ fn constrain_untyped_closure(
vars.push(closure_var);
vars.push(fn_var);

let body_type = constraints.push_expected_type(ForReason(
let return_type_index = constraints.push_expected_type(ForReason(
Reason::FunctionOutput,
return_type_index,
loc_body_expr.region,
));

let ret_constraint = env.with_fx_expectation(fx_var, None, |env| {
constrain_expr(
let returns_constraint = env.with_fx_expectation(fx_var, None, |env| {
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});

let mut early_return_constraints = Vec::with_capacity(early_returns.len());
for (early_return_variable, early_return_region) in early_returns {
let early_return_var = constraints.push_variable(*early_return_variable);
let early_return_con = constraints.equal_types(
early_return_var,
body_type,
Category::Return,
*early_return_region,
return_type_index,
);

early_return_constraints.push(early_return_con);
}
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);

let early_returns_constraint = constraints.and_constraint(early_return_constraints);
for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type_index,
Category::Return,
*early_return_region,
);

return_constraints.push(early_return_con);
}

constraints.and_constraint(return_constraints)
});

// make sure the captured symbols are sorted!
debug_assert_eq!(captured_symbols.to_vec(), {
Expand Down Expand Up @@ -231,7 +232,7 @@ fn constrain_untyped_closure(
pattern_state.vars,
pattern_state.headers,
pattern_state_constraints,
ret_constraint,
returns_constraint,
Generalizable(true),
),
constraints.and_constraint(pattern_state.delayed_fx_suffix_constraints),
Expand All @@ -242,7 +243,6 @@ fn constrain_untyped_closure(
region,
fn_var,
),
early_returns_constraint,
closure_constraint,
constraints.flex_to_pure(fx_var),
];
Expand Down Expand Up @@ -1423,23 +1423,20 @@ pub fn constrain_expr(
return_var,
} => {
let return_type_index = constraints.push_variable(*return_var);

let expected_return_value = constraints.push_expected_type(ForReason(
Reason::FunctionOutput,
return_type_index,
return_value.region,
));

let return_con = constrain_expr(
constrain_expr(
types,
constraints,
env,
return_value.region,
&return_value.value,
expected_return_value,
);

constraints.exists([*return_var], return_con)
)
}
Tag {
tag_union_var: variant_var,
Expand Down Expand Up @@ -2075,16 +2072,42 @@ fn constrain_function_def(
constraints.push_type(types, fn_type)
};

let ret_constraint = {
let con = constrain_expr(
let returns_constraint = {
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
return_type_annotation_expected,
);
attach_resolution_constraints(constraints, env, con)

let mut return_constraints =
Vec::with_capacity(function_def.early_returns.len() + 1);
return_constraints.push(return_con);

for (early_return_variable, early_return_region) in &function_def.early_returns {
let early_return_type_expected =
constraints.push_expected_type(Expected::ForReason(
Reason::FunctionOutput,
ret_type_index,
*early_return_region,
));

vars.push(*early_return_variable);
let early_return_con = constraints.equal_types_var(
*early_return_variable,
early_return_type_expected,
Category::Return,
*early_return_region,
);

return_constraints.push(early_return_con);
}

let returns_constraint = constraints.and_constraint(return_constraints);

attach_resolution_constraints(constraints, env, returns_constraint)
};

vars.push(expr_var);
Expand All @@ -2104,7 +2127,7 @@ fn constrain_function_def(
argument_pattern_state.vars,
argument_pattern_state.headers,
defs_constraint,
ret_constraint,
returns_constraint,
// This is a syntactic function, it can be generalized
Generalizable(true),
),
Expand Down Expand Up @@ -2860,6 +2883,7 @@ fn constrain_typed_def(
function_type: fn_var,
closure_type: closure_var,
return_type: ret_var,
early_returns,
fx_type: fx_var,
captured_symbols,
arguments,
Expand Down Expand Up @@ -2929,7 +2953,7 @@ fn constrain_typed_def(
constraints.push_type(types, fn_type)
};

let body_type = constraints.push_expected_type(FromAnnotation(
let return_type = constraints.push_expected_type(FromAnnotation(
def.loc_pattern.clone(),
arguments.len(),
AnnotationSource::TypedBody {
Expand All @@ -2938,18 +2962,35 @@ fn constrain_typed_def(
ret_type_index,
));

let ret_constraint = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
return_type,
);

let ret_constraint = attach_resolution_constraints(constraints, env, ret_constraint);
let mut return_constraints = Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);

for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type,
Category::Return,
*early_return_region,
);

return_constraints.push(early_return_con);
}

let returns_constraint = constraints.and_constraint(return_constraints);

attach_resolution_constraints(constraints, env, returns_constraint)
});

vars.push(*fn_var);
let defs_constraint = constraints.and_constraint(argument_pattern_state.constraints);
Expand All @@ -2962,7 +3003,7 @@ fn constrain_typed_def(
argument_pattern_state.vars,
argument_pattern_state.headers,
defs_constraint,
ret_constraint,
returns_constraint,
// This is a syntactic function, it can be generalized
Generalizable(true),
),
Expand Down Expand Up @@ -3969,18 +4010,38 @@ fn constraint_recursive_function(
constraints.push_type(types, typ)
};

let expr_con = env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
expected,
)
});
let expr_con = attach_resolution_constraints(constraints, env, expr_con);
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let expected = constraints.push_expected_type(NoExpectation(ret_type_index));
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
expected,
);

let mut return_constraints =
Vec::with_capacity(function_def.early_returns.len() + 1);
return_constraints.push(return_con);

for (early_return_variable, early_return_region) in &function_def.early_returns
{
let early_return_con = constraints.equal_types_var(
*early_return_variable,
expected,
Category::Return,
*early_return_region,
);

return_constraints.push(early_return_con);
}

let returns_constraint = constraints.and_constraint(return_constraints);

attach_resolution_constraints(constraints, env, returns_constraint)
});

vars.push(expr_var);

Expand All @@ -3992,7 +4053,7 @@ fn constraint_recursive_function(
argument_pattern_state.vars,
argument_pattern_state.headers,
state_constraints,
expr_con,
returns_constraint,
// Syntactic function can be generalized
Generalizable(true),
),
Expand Down Expand Up @@ -4454,6 +4515,7 @@ fn rec_defs_help(
function_type: fn_var,
closure_type: closure_var,
return_type: ret_var,
early_returns,
fx_type: fx_var,
captured_symbols,
arguments,
Expand Down Expand Up @@ -4521,22 +4583,40 @@ fn rec_defs_help(
let typ = types.function(pattern_types, lambda_set, ret_type, fx_type);
constraints.push_type(types, typ)
};
let expr_con =
let returns_constraint =
env.with_fx_expectation(fx_var, Some(annotation.region), |env| {
let body_type =
let return_type_expected =
constraints.push_expected_type(NoExpectation(ret_type_index));

constrain_expr(
let return_con = constrain_expr(
types,
constraints,
env,
loc_body_expr.region,
&loc_body_expr.value,
body_type,
)
});
return_type_expected,
);

let expr_con = attach_resolution_constraints(constraints, env, expr_con);
let mut return_constraints =
Vec::with_capacity(early_returns.len() + 1);
return_constraints.push(return_con);

for (early_return_variable, early_return_region) in early_returns {
let early_return_con = constraints.equal_types_var(
*early_return_variable,
return_type_expected,
Category::Return,
*early_return_region,
);

return_constraints.push(early_return_con);
}

let returns_constraint =
constraints.and_constraint(return_constraints);

attach_resolution_constraints(constraints, env, returns_constraint)
});

vars.push(*fn_var);

Expand All @@ -4551,7 +4631,7 @@ fn rec_defs_help(
argument_pattern_state.vars,
argument_pattern_state.headers,
state_constraints,
expr_con,
returns_constraint,
generalizable,
),
// Check argument suffixes against usage
Expand Down
4 changes: 3 additions & 1 deletion crates/compiler/gen_dev/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,9 @@ trait Backend<'a> {
cond_layout,
branches,
default_branch,
ret_layout,
// always use the proc's ret_layout, as early returns can make
// this ret_layout inaccurate
ret_layout: _,
} => {
self.load_literal_symbols(&[*cond_symbol]);
self.build_switch(
Expand Down
Loading

0 comments on commit 22423ca

Please sign in to comment.