Skip to content

Commit

Permalink
Merge PR coq#18492: Fix caching in Tacred.reference_eval and cache un…
Browse files Browse the repository at this point in the history
…iv poly

Reviewed-by: ppedrot
Co-authored-by: ppedrot <[email protected]>
  • Loading branch information
coqbot-app[bot] and ppedrot authored Jan 15, 2024
2 parents 4b5ee43 + 142cff0 commit 0adc190
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 50 deletions.
122 changes: 72 additions & 50 deletions pretyping/tacred.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,31 @@ let destEvalRefU sigma c = match EConstr.kind sigma c with
| Evar ev -> (EvalEvar ev, EInstance.empty)
| _ -> anomaly (Pp.str "Not an unfoldable reference.")

let reference_opt_value env sigma eval u =
module CacheTable = Hashtbl.Make(struct
type t = Constant.t * UVars.Instance.t

(* WARNING if we use CanOrd and we have [M.x := N.x] unfolding M.x
first will put [M.x -> N.x] in the cache, then trying to unfold
N.x will return N.x ie loop. *)
let equal (c,u) (c',u') =
Constant.UserOrd.equal c c' && UVars.Instance.equal u u'

let hash (c,u) =
Hashset.Combine.combine (Constant.UserOrd.hash c) (UVars.Instance.hash u)
end)

let reference_opt_value cache env sigma eval u =
match eval with
| EvalConst cst ->
let u = EInstance.kind sigma u in
Option.map EConstr.of_constr (constant_opt_value_in env (cst,u))
let cu = (cst, u) in
begin match CacheTable.find_opt cache cu with
| Some v -> v
| None ->
let v = Option.map EConstr.of_constr (constant_opt_value_in env cu) in
CacheTable.add cache cu v;
v
end
| EvalVar id ->
env |> lookup_named id |> NamedDecl.get_value
| EvalRel n ->
Expand All @@ -143,8 +163,8 @@ let reference_opt_value env sigma eval u =
| c -> Some (EConstr.of_kind c)

exception NotEvaluable
let reference_value env sigma c u =
match reference_opt_value env sigma c u with
let reference_value cache env sigma c u =
match reference_opt_value cache env sigma c u with
| None -> raise NotEvaluable
| Some d -> d

Expand All @@ -170,12 +190,6 @@ type constant_evaluation =
| EliminationProj of int
| NotAnElimination

(* We use a cache registered as a global table *)

type frozen = constant_evaluation Cmap.t

let eval_table = Summary.ref (Cmap.empty : frozen) ~name:"evaluation"

(* [compute_consteval] determines whether f is an "elimination constant"
either [yn:Tn]..[y1:T1](match yi with f1..fk end g1 ..gp)
Expand Down Expand Up @@ -258,8 +272,8 @@ let check_fix_reversibility env sigma ref u labs args minarg refs ((lv,i),_ as f
refolding_data;
}

let compute_fix_wrapper allowed_reds env sigma ref u =
try match reference_opt_value env sigma ref u with
let compute_fix_wrapper ((cache,_),allowed_reds) env sigma ref u =
try match reference_opt_value cache env sigma ref u with
| None -> None
| Some c ->
let labs, ccl = whd_decompose_lambda env sigma c in
Expand Down Expand Up @@ -310,7 +324,7 @@ let deactivate_delta allowed_reds =
for unary fixpoints and to the last constant encapsulating the Fix
for mutual fixpoints *)

let compute_consteval allowed_reds env sigma ref u =
let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u =
let allowed_reds_no_delta = deactivate_delta allowed_reds in
let rec srec env all_abs lastref lastu onlyproj c =
let c', args = whd_stack_gen allowed_reds_no_delta env sigma c in
Expand All @@ -331,7 +345,7 @@ let compute_consteval allowed_reds env sigma ref u =
with Elimconst -> NotAnElimination
else
(* Try to refold to [lastref] *)
let last_labs, last_args, names = invert_names allowed_reds env sigma lastref lastu names i in
let last_labs, last_args, names = invert_names cache_reds env sigma lastref lastu names i in
try check_fix_reversibility env sigma lastref lastu last_labs last_args n_all_abs names fix
with Elimconst -> NotAnElimination)
| Case (_,_,_,_,_,d,_) when isRel sigma d && not onlyproj -> EliminationCases (List.length all_abs)
Expand All @@ -340,26 +354,29 @@ let compute_consteval allowed_reds env sigma ref u =
| _ when isTransparentEvalRef env sigma (RedFlags.red_transparent allowed_reds) c' ->
(* Continue stepwise unfolding from [c' args] *)
let ref, u = destEvalRefU sigma c' in
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotAnElimination (* e.g. if a rel *)
| Some c -> srec env all_abs ref u onlyproj (applist (c, args)))
| _ -> NotAnElimination
in
match reference_opt_value env sigma ref u with
match reference_opt_value cache env sigma ref u with
| None -> NotAnElimination
| Some c -> srec env [] ref u false c

let reference_eval allowed_reds env sigma ref u =
let make_simpl_cache () =
CacheTable.create 12, CacheTable.create 12

let reference_eval ((_,cache),_ as cache_reds) env sigma ref u =
match ref with
| EvalConst cst as ref when EInstance.is_empty u ->
(try
Cmap.find cst !eval_table
with Not_found -> begin
let v = compute_consteval allowed_reds env sigma ref u in
eval_table := Cmap.add cst v !eval_table;
v
end)
| ref -> compute_consteval allowed_reds env sigma ref u
| EvalConst cst as ref ->
let cu = cst, EInstance.kind sigma u in
(match CacheTable.find_opt cache cu with
| Some v -> v
| None ->
let v = compute_consteval cache_reds env sigma ref u in
CacheTable.add cache cu v;
v)
| ref -> compute_consteval cache_reds env sigma ref u

(* If f is bound to EliminationFix (n',refs,infos), then n' is the minimal
number of args for triggering the reduction and infos is
Expand Down Expand Up @@ -622,7 +639,7 @@ let make_simpl_reds env =
constants by keeping the name of the constants in the recursive calls;
it fails if no redex is around *)

let rec red_elim_const allowed_reds env sigma ref u largs =
let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs =
let open ReductionBehaviour in
let nargs = List.length largs in
let* largs, unfold_anyway, unfold_nonelim, nocase =
Expand All @@ -634,47 +651,47 @@ let rec red_elim_const allowed_reds env sigma ref u largs =
| Some (UnfoldWhen { recargs = x::l } | UnfoldWhenNoMatch { recargs = x::l })
when nargs <= List.fold_left max x l -> NotReducible
| Some (UnfoldWhen { recargs; nargs = None }) ->
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
false,
false,
false)
| Some (UnfoldWhenNoMatch { recargs; nargs = None }) ->
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
false,
false,
true)
| Some (UnfoldWhen { recargs; nargs = Some n }) ->
let is_empty = List.is_empty recargs in
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
is_empty && nargs >= n,
not is_empty && nargs >= n,
false)
| Some (UnfoldWhenNoMatch { recargs; nargs = Some n }) ->
let is_empty = List.is_empty recargs in
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
is_empty && nargs >= n,
not is_empty && nargs >= n,
true)
in
let ans = match reference_eval allowed_reds env sigma ref u with
let ans = match reference_eval cache_reds env sigma ref u with
| EliminationCases n when nargs >= n ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
let c', lrest = whd_nothing_for_iota env sigma (c, largs) in
let* ans = special_red_case allowed_reds env sigma (EConstr.destCase sigma c') in
let* ans = special_red_case cache_reds env sigma (EConstr.destCase sigma c') in
Reduced ((ans, lrest), nocase)
| EliminationProj n when nargs >= n ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
let c', lrest = whd_nothing_for_iota env sigma (c, largs) in
let* ans = reduce_proj allowed_reds env sigma c' in
let* ans = reduce_proj cache_reds env sigma c' in
Reduced ((ans, lrest), nocase)
| EliminationFix {trigger_min_args; refolding_target; refolding_data}
when nargs >= trigger_min_args ->
let rec descend (ref,u) args =
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
if evaluable_reference_eq sigma ref refolding_target then
(c,args)
else
Expand All @@ -683,16 +700,16 @@ let rec red_elim_const allowed_reds env sigma ref u largs =
let (_, midargs as s) = descend (ref,u) largs in
let d, lrest = whd_nothing_for_iota env sigma s in
let f = refolding_data, midargs in
let* (c, rest) = reduce_fix allowed_reds env sigma (Some f) (destFix sigma d) lrest in
let* (c, rest) = reduce_fix cache_reds env sigma (Some f) (destFix sigma d) lrest in
Reduced ((c, rest), nocase)
| NotAnElimination when unfold_nonelim ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase)
| _ -> NotReducible
in
match ans with
| NotReducible when unfold_anyway ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase)
| _ -> ans

Expand Down Expand Up @@ -839,11 +856,11 @@ and whd_construct_stack allowed_reds env sigma s =

(* reduce until finding an applied constructor (or primitive value) or fail *)

and whd_construct allowed_reds env sigma s =
and whd_construct ((cache,_),_ as allowed_reds) env sigma s =
let (constr, cargs) = whd_simpl_stack allowed_reds env sigma s in
match match_eval_ref env sigma constr cargs with
| Some (ref, u) ->
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotReducible
| Some gvalue ->
if reducible_construct sigma gvalue then Reduced (Some (ref, u), gvalue, cargs)
Expand All @@ -862,6 +879,7 @@ and whd_construct allowed_reds env sigma s =

let try_red_product env sigma c =
let simpfun c = clos_norm_flags RedFlags.betaiotazeta env sigma c in
let cache = CacheTable.create 12 in
let rec redrec env x =
let x = whd_betaiota env sigma x in
match EConstr.kind sigma x with
Expand Down Expand Up @@ -901,7 +919,7 @@ let try_red_product env sigma c =
| Some (ref, u) ->
(* TO DO: re-fold fixpoints after expansion *)
(* to get true one-step reductions *)
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotReducible
| Some c -> Reduced c)
| _ -> NotReducible)
Expand Down Expand Up @@ -971,8 +989,9 @@ let whd_simpl_orelse_delta_but_fix_old env sigma c =

let whd_simpl_orelse_delta_but_fix env sigma c =
let reds = make_simpl_reds env in
let cache = make_simpl_cache() in
let rec redrec s =
let (constr, stack as s') = whd_simpl_stack reds env sigma s in
let (constr, stack as s') = whd_simpl_stack (cache,reds) env sigma s in
match match_eval_ref_value env sigma constr stack with
| Some c ->
(match EConstr.kind sigma (snd (decompose_lambda sigma c)) with
Expand Down Expand Up @@ -1003,13 +1022,15 @@ let hnf_constr env sigma c =
let whd_simpl_with_reds allowed_reds env sigma c =
applist (whd_simpl_stack allowed_reds env sigma (c, []))

let whd_simpl env sigma x = whd_simpl_with_reds (make_simpl_reds env) env sigma x
let whd_simpl env sigma x =
whd_simpl_with_reds (make_simpl_cache(), make_simpl_reds env) env sigma x

let simpl env sigma c =
let allowed_reds = make_simpl_reds env in
let cache = make_simpl_cache () in
let rec strongrec env t =
map_constr_with_full_binders env sigma push_rel strongrec env
(whd_simpl_with_reds allowed_reds env sigma t) in
(whd_simpl_with_reds (cache,allowed_reds) env sigma t) in
strongrec env c

(* Reduction at specific subterms *)
Expand Down Expand Up @@ -1290,6 +1311,7 @@ let find_hnf_rectype env sigma t =
exception NotStepReducible

let one_step_reduce env sigma c =
let (cache,_), _ as cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in
let rec redrec (x, stack) =
match EConstr.kind sigma x with
| Lambda (n,t,c) ->
Expand All @@ -1300,21 +1322,21 @@ let one_step_reduce env sigma c =
| LetIn (_,f,_,cl) -> (Vars.subst1 f cl,stack)
| Cast (c,_,_) -> redrec (c,stack)
| Case (ci,u,pms,p,iv,c,lf) ->
begin match special_red_case RedFlags.betadeltazeta env sigma (ci,u,pms,p,iv,c,lf) with
begin match special_red_case cache_reds env sigma (ci,u,pms,p,iv,c,lf) with
| Reduced c -> (c, stack)
| NotReducible -> raise NotStepReducible
end
| Fix fix ->
begin match reduce_fix RedFlags.betadeltazeta env sigma None fix stack with
begin match reduce_fix cache_reds env sigma None fix stack with
| Reduced s' -> s'
| NotReducible -> raise NotStepReducible
end
| _ when isEvalRef env sigma x ->
let ref,u = destEvalRefU sigma x in
begin match red_elim_const RedFlags.betadeltazeta env sigma ref u stack with
begin match red_elim_const cache_reds env sigma ref u stack with
| Reduced (c, _) -> c
| NotReducible ->
match reference_opt_value env sigma ref u with
match reference_opt_value cache env sigma ref u with
| Some d -> (d, stack)
| None -> raise NotStepReducible
end
Expand Down
14 changes: 14 additions & 0 deletions test-suite/bugs/bug_18490.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Definition plus' := plus.

Definition foo := Eval simpl in plus' 1 2.

Arguments plus : simpl never.

Lemma test P : P 3 -> P (plus' 1 2).
Proof.
intros p3.
simpl.
match goal with |- P (plus' 1 2) => idtac end.
(* the order of "Eval simpl" and "Arguments" above resulted in
incorrect cache and the goal was reduced *)
Abort.

0 comments on commit 0adc190

Please sign in to comment.