Skip to content

Commit

Permalink
Merge pull request #6030 from roc-lang/pattern-match-rest-as
Browse files Browse the repository at this point in the history
Pattern match rest as
  • Loading branch information
bhansconnect authored Nov 21, 2023
2 parents 1f0f25f + af5b209 commit 08ee6ed
Show file tree
Hide file tree
Showing 17 changed files with 892 additions and 838 deletions.
5 changes: 1 addition & 4 deletions crates/compiler/builtins/roc/List.roc
Original file line number Diff line number Diff line change
Expand Up @@ -1123,10 +1123,7 @@ findLastIndex = \list, matches ->
## Some languages have a function called **`slice`** which works similarly to this.
sublist : List elem, { start : Nat, len : Nat } -> List elem
sublist = \list, config ->
if config.len == 0 then
[]
else
sublistLowlevel list config.start config.len
sublistLowlevel list config.start config.len

## low-level slicing operation that does no bounds checking
sublistLowlevel : List elem, Nat, Nat -> List elem
Expand Down
42 changes: 40 additions & 2 deletions crates/compiler/exhaustive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ pub fn is_useful(mut old_matrix: PatternMatrix, mut vector: Row) -> bool {
vector.extend(args);
} else {
// TODO turn this into an iteration over the outer loop rather than bouncing
vector.extend(args);
for list_ctor in spec_list_ctors {
let mut old_matrix = old_matrix.clone();
let mut spec_matrix = Vec::with_capacity(old_matrix.len());
Expand All @@ -400,10 +399,19 @@ pub fn is_useful(mut old_matrix: PatternMatrix, mut vector: Row) -> bool {
&mut spec_matrix,
);

if is_useful(spec_matrix, vector.clone()) {
let mut vector = vector.clone();
specialize_row_with_polymorphic_list(
&mut vector,
&args,
arity,
list_ctor,
);

if is_useful(spec_matrix, vector) {
return true;
}
}

return false;
}
}
Expand Down Expand Up @@ -504,6 +512,36 @@ fn specialize_matrix_by_list(
}
}

fn specialize_row_with_polymorphic_list(
row: &mut Vec<Pattern>,
list_element_patterns: &[Pattern],
polymorphic_list_ctor: ListArity,
specialized_list_ctor: ListArity,
) {
let min_len = specialized_list_ctor.min_len();
if list_element_patterns.len() > min_len {
row.extend(list_element_patterns.iter().cloned());
}

let (patterns_before, patterns_after) = match polymorphic_list_ctor {
ListArity::Slice(before, after) => (
&list_element_patterns[..before],
&list_element_patterns[list_element_patterns.len() - after..],
),
ListArity::Exact(_) => (list_element_patterns, &[] as &[Pattern]),
};

let middle_any_patterns_needed =
specialized_list_ctor.min_len() - polymorphic_list_ctor.min_len();
let middle_patterns = std::iter::repeat(Anything).take(middle_any_patterns_needed);

row.extend(
(patterns_before.iter().cloned())
.chain(middle_patterns)
.chain(patterns_after.iter().cloned()),
);
}

// Specialize a row that matches a list's constructor(s).
//
// See the docs on [build_list_ctors_covering_patterns] for more information on how list
Expand Down
54 changes: 54 additions & 0 deletions crates/compiler/load/tests/test_reporting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12596,6 +12596,60 @@ In roc, functions are always written as a lambda, like{}
"###
);

test_no_problem!(
list_match_spread_required_front_back,
indoc!(
r#"
l : List [A, B]
when l is
[A, ..] -> ""
[.., A] -> ""
[..] -> ""
"#
)
);

test_report!(
list_match_spread_redundant_front_back,
indoc!(
r#"
l : List [A]
when l is
[A, ..] -> ""
[.., A] -> ""
[..] -> ""
"#
),
@r###"
── REDUNDANT PATTERN ───────────────────────────────────── /code/proj/Main.roc ─
The 2nd pattern is redundant:
6│ when l is
7│ [A, ..] -> ""
8│> [.., A] -> ""
9│ [..] -> ""
Any value of this shape will be handled by a previous pattern, so this
one should be removed.
"###
);

test_no_problem!(
list_match_spread_as,
indoc!(
r#"
l : List [A, B]
when l is
[A, .. as rest] | [.. as rest, A] -> rest
[.. as rest] -> rest
"#
)
);

test_no_problem!(
list_match_exhaustive_empty_and_rest_with_unary_head,
indoc!(
Expand Down
4 changes: 4 additions & 0 deletions crates/compiler/mono/src/ir/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,10 @@ fn test_for_pattern<'a>(pattern: &Pattern<'a>) -> Option<Test<'a>> {

List {
arity,
list_layout: _,
element_layout: _,
elements: _,
opt_rest: _,
} => IsListLen {
bound: match arity {
ListArity::Exact(_) => ListLenBound::Exact,
Expand Down Expand Up @@ -908,7 +910,9 @@ fn to_relevant_branch_help<'a>(
List {
arity: my_arity,
elements,
list_layout: _,
element_layout: _,
opt_rest: _,
} => match test {
IsListLen {
bound: test_bound,
Expand Down
80 changes: 79 additions & 1 deletion crates/compiler/mono/src/ir/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ pub enum Pattern<'a> {
},
List {
arity: ListArity,
list_layout: InLayout<'a>,
element_layout: InLayout<'a>,
elements: Vec<'a, Pattern<'a>>,
opt_rest: Option<(usize, Option<Symbol>)>,
},
}

Expand Down Expand Up @@ -1050,10 +1052,18 @@ fn from_can_pattern_help<'a>(
}

List {
list_var: _,
list_var,
elem_var,
patterns,
} => {
let list_layout = match layout_cache.from_var(env.arena, *list_var, env.subs) {
Ok(lay) => lay,
Err(LayoutProblem::UnresolvedTypeVar(_)) => {
return Err(RuntimeError::UnresolvedTypeVar)
}
Err(LayoutProblem::Erroneous) => return Err(RuntimeError::ErroneousType),
};

let element_layout = match layout_cache.from_var(env.arena, *elem_var, env.subs) {
Ok(lay) => lay,
Err(LayoutProblem::UnresolvedTypeVar(_)) => {
Expand All @@ -1073,8 +1083,10 @@ fn from_can_pattern_help<'a>(

Ok(Pattern::List {
arity,
list_layout,
element_layout,
elements: mono_patterns,
opt_rest: patterns.opt_rest,
})
}
}
Expand Down Expand Up @@ -1240,17 +1252,21 @@ fn store_pattern_help<'a>(

List {
arity,
list_layout,
element_layout,
elements,
opt_rest,
} => {
return store_list_pattern(
env,
procs,
layout_cache,
outer_symbol,
*arity,
*list_layout,
*element_layout,
elements,
opt_rest,
stmt,
)
}
Expand Down Expand Up @@ -1447,8 +1463,10 @@ fn store_list_pattern<'a>(
layout_cache: &mut LayoutCache<'a>,
list_sym: Symbol,
list_arity: ListArity,
list_layout: InLayout<'a>,
element_layout: InLayout<'a>,
elements: &[Pattern<'a>],
opt_rest: &Option<(usize, Option<Symbol>)>,
mut stmt: Stmt<'a>,
) -> StorePattern<'a> {
use Pattern::*;
Expand Down Expand Up @@ -1526,13 +1544,73 @@ fn store_list_pattern<'a>(
}
}

stmt = store_list_rest(env, list_sym, list_arity, list_layout, opt_rest, stmt);

if is_productive {
StorePattern::Productive(stmt)
} else {
StorePattern::NotProductive(stmt)
}
}

fn store_list_rest<'a>(
env: &mut Env<'a, '_>,
list_sym: Symbol,
list_arity: ListArity,
list_layout: InLayout<'a>,
opt_rest: &Option<(usize, Option<Symbol>)>,
mut stmt: Stmt<'a>,
) -> Stmt<'a> {
if let Some((index, Some(rest_sym))) = opt_rest {
let usize_layout = Layout::usize(env.target_info);

let total_dropped = list_arity.min_len();

let total_dropped_sym = env.unique_symbol();
let total_dropped_expr = Expr::Literal(Literal::Int((total_dropped as u128).to_ne_bytes()));

let list_len_sym = env.unique_symbol();
let list_len_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::ListLen,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([list_sym]),
});

let rest_len_sym = env.unique_symbol();
let rest_len_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::NumSub,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([list_len_sym, total_dropped_sym]),
});

let start_sym = env.unique_symbol();
let start_expr = Expr::Literal(Literal::Int((*index as u128).to_ne_bytes()));

let rest_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::ListSublist,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([list_sym, start_sym, rest_len_sym]),
});
let needed_stores = [
(total_dropped_sym, total_dropped_expr, usize_layout),
(list_len_sym, list_len_expr, usize_layout),
(rest_len_sym, rest_len_expr, usize_layout),
(start_sym, start_expr, usize_layout),
(*rest_sym, rest_expr, list_layout),
];
for (sym, expr, lay) in needed_stores.into_iter().rev() {
stmt = Stmt::Let(sym, expr, lay, env.arena.alloc(stmt));
}
}
stmt
}

#[allow(clippy::too_many_arguments)]
fn store_tag_pattern<'a>(
env: &mut Env<'a, '_>,
Expand Down
23 changes: 23 additions & 0 deletions crates/compiler/test_gen/src/gen_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4039,4 +4039,27 @@ mod pattern_match {
RocList<u8>
)
}

#[test]
fn rest_as() {
assert_evals_to!(
r#"
helper : List U8 -> U8
helper = \l -> when l is
[1, .. as rest, 1] -> helper rest
[1, .. as rest] -> helper rest
[.. as rest, 1] -> helper rest
[first, .., last] | [first as last] -> first + last
[] -> 0
[
helper [1, 1, 1],
helper [2, 1],
helper [1, 1, 2, 4, 1],
helper [1, 1, 8, 7, 3, 1, 1, 1],
]
"#,
RocList::from_slice(&[0, 4, 6, 11]),
RocList<u8>
)
}
}
Loading

0 comments on commit 08ee6ed

Please sign in to comment.