From 880f3dad8d35fbbf781b9c521e3fe1185e1a5090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20Gilbert?= Date: Thu, 11 Jan 2024 17:37:04 +0100 Subject: [PATCH] tacred: also cache result of reference_opt_value --- pretyping/tacred.ml | 96 +++++++++++++++++++++++++-------------------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/pretyping/tacred.ml b/pretyping/tacred.ml index 06f9cb1d0bb2e..be284253eaa5f 100644 --- a/pretyping/tacred.ml +++ b/pretyping/tacred.ml @@ -128,11 +128,31 @@ let destEvalRefU sigma c = match EConstr.kind sigma c with | Evar ev -> (EvalEvar ev, EInstance.empty) | _ -> anomaly (Pp.str "Not an unfoldable reference.") -let reference_opt_value env sigma eval u = +module CacheTable = Hashtbl.Make(struct + type t = Constant.t * UVars.Instance.t + + (* WARNING if we use CanOrd and we have [M.x := N.x] unfolding M.x + first will put [M.x -> N.x] in the cache, then trying to unfold + N.x will return N.x ie loop. *) + let equal (c,u) (c',u') = + Constant.UserOrd.equal c c' && UVars.Instance.equal u u' + + let hash (c,u) = + Hashset.Combine.combine (Constant.UserOrd.hash c) (UVars.Instance.hash u) + end) + +let reference_opt_value cache env sigma eval u = match eval with | EvalConst cst -> let u = EInstance.kind sigma u in - Option.map EConstr.of_constr (constant_opt_value_in env (cst,u)) + let cu = (cst, u) in + begin match CacheTable.find_opt cache cu with + | Some v -> v + | None -> + let v = Option.map EConstr.of_constr (constant_opt_value_in env cu) in + CacheTable.add cache cu v; + v + end | EvalVar id -> env |> lookup_named id |> NamedDecl.get_value | EvalRel n -> @@ -143,8 +163,8 @@ let reference_opt_value env sigma eval u = | c -> Some (EConstr.of_kind c) exception NotEvaluable -let reference_value env sigma c u = - match reference_opt_value env sigma c u with +let reference_value cache env sigma c u = + match reference_opt_value cache env sigma c u with | None -> raise NotEvaluable | Some d -> d @@ -252,8 +272,8 @@ let check_fix_reversibility env sigma ref u labs args minarg refs ((lv,i),_ as f refolding_data; } -let compute_fix_wrapper allowed_reds env sigma ref u = - try match reference_opt_value env sigma ref u with +let compute_fix_wrapper ((cache,_),allowed_reds) env sigma ref u = + try match reference_opt_value cache env sigma ref u with | None -> None | Some c -> let labs, ccl = whd_decompose_lambda env sigma c in @@ -304,7 +324,7 @@ let deactivate_delta allowed_reds = for unary fixpoints and to the last constant encapsulating the Fix for mutual fixpoints *) -let compute_consteval allowed_reds env sigma ref u = +let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u = let allowed_reds_no_delta = deactivate_delta allowed_reds in let rec srec env all_abs lastref lastu onlyproj c = let c', args = whd_stack_gen allowed_reds_no_delta env sigma c in @@ -325,7 +345,7 @@ let compute_consteval allowed_reds env sigma ref u = with Elimconst -> NotAnElimination else (* Try to refold to [lastref] *) - let last_labs, last_args, names = invert_names allowed_reds env sigma lastref lastu names i in + let last_labs, last_args, names = invert_names cache_reds env sigma lastref lastu names i in try check_fix_reversibility env sigma lastref lastu last_labs last_args n_all_abs names fix with Elimconst -> NotAnElimination) | Case (_,_,_,_,_,d,_) when isRel sigma d && not onlyproj -> EliminationCases (List.length all_abs) @@ -334,38 +354,29 @@ let compute_consteval allowed_reds env sigma ref u = | _ when isTransparentEvalRef env sigma (RedFlags.red_transparent allowed_reds) c' -> (* Continue stepwise unfolding from [c' args] *) let ref, u = destEvalRefU sigma c' in - (match reference_opt_value env sigma ref u with + (match reference_opt_value cache env sigma ref u with | None -> NotAnElimination (* e.g. if a rel *) | Some c -> srec env all_abs ref u onlyproj (applist (c, args))) | _ -> NotAnElimination in - match reference_opt_value env sigma ref u with + match reference_opt_value cache env sigma ref u with | None -> NotAnElimination | Some c -> srec env [] ref u false c -module CacheTable = Hashtbl.Make(struct - type t = Constant.t * UVars.Instance.t - let equal (c,u) (c',u') = - Constant.CanOrd.equal c c' && UVars.Instance.equal u u' - - let hash (c,u) = - Hashset.Combine.combine (Constant.CanOrd.hash c) (UVars.Instance.hash u) - end) - let make_simpl_cache () = - CacheTable.create 12 + CacheTable.create 12, CacheTable.create 12 -let reference_eval (cache,allowed_reds) env sigma ref u = +let reference_eval ((_,cache),_ as cache_reds) env sigma ref u = match ref with | EvalConst cst as ref -> let cu = cst, EInstance.kind sigma u in (match CacheTable.find_opt cache cu with | Some v -> v | None -> - let v = compute_consteval allowed_reds env sigma ref u in + let v = compute_consteval cache_reds env sigma ref u in CacheTable.add cache cu v; v) - | ref -> compute_consteval allowed_reds env sigma ref u + | ref -> compute_consteval cache_reds env sigma ref u (* If f is bound to EliminationFix (n',refs,infos), then n' is the minimal number of args for triggering the reduction and infos is @@ -628,7 +639,7 @@ let make_simpl_reds env = constants by keeping the name of the constants in the recursive calls; it fails if no redex is around *) -let rec red_elim_const allowed_reds env sigma ref u largs = +let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs = let open ReductionBehaviour in let nargs = List.length largs in let* largs, unfold_anyway, unfold_nonelim, nocase = @@ -640,47 +651,47 @@ let rec red_elim_const allowed_reds env sigma ref u largs = | Some (UnfoldWhen { recargs = x::l } | UnfoldWhenNoMatch { recargs = x::l }) when nargs <= List.fold_left max x l -> NotReducible | Some (UnfoldWhen { recargs; nargs = None }) -> - let* params = reduce_params allowed_reds env sigma largs recargs in + let* params = reduce_params cache_reds env sigma largs recargs in Reduced (params, false, false, false) | Some (UnfoldWhenNoMatch { recargs; nargs = None }) -> - let* params = reduce_params allowed_reds env sigma largs recargs in + let* params = reduce_params cache_reds env sigma largs recargs in Reduced (params, false, false, true) | Some (UnfoldWhen { recargs; nargs = Some n }) -> let is_empty = List.is_empty recargs in - let* params = reduce_params allowed_reds env sigma largs recargs in + let* params = reduce_params cache_reds env sigma largs recargs in Reduced (params, is_empty && nargs >= n, not is_empty && nargs >= n, false) | Some (UnfoldWhenNoMatch { recargs; nargs = Some n }) -> let is_empty = List.is_empty recargs in - let* params = reduce_params allowed_reds env sigma largs recargs in + let* params = reduce_params cache_reds env sigma largs recargs in Reduced (params, is_empty && nargs >= n, not is_empty && nargs >= n, true) in - let ans = match reference_eval allowed_reds env sigma ref u with + let ans = match reference_eval cache_reds env sigma ref u with | EliminationCases n when nargs >= n -> - let c = reference_value env sigma ref u in + let c = reference_value cache env sigma ref u in let c', lrest = whd_nothing_for_iota env sigma (c, largs) in - let* ans = special_red_case allowed_reds env sigma (EConstr.destCase sigma c') in + let* ans = special_red_case cache_reds env sigma (EConstr.destCase sigma c') in Reduced ((ans, lrest), nocase) | EliminationProj n when nargs >= n -> - let c = reference_value env sigma ref u in + let c = reference_value cache env sigma ref u in let c', lrest = whd_nothing_for_iota env sigma (c, largs) in - let* ans = reduce_proj allowed_reds env sigma c' in + let* ans = reduce_proj cache_reds env sigma c' in Reduced ((ans, lrest), nocase) | EliminationFix {trigger_min_args; refolding_target; refolding_data} when nargs >= trigger_min_args -> let rec descend (ref,u) args = - let c = reference_value env sigma ref u in + let c = reference_value cache env sigma ref u in if evaluable_reference_eq sigma ref refolding_target then (c,args) else @@ -689,16 +700,16 @@ let rec red_elim_const allowed_reds env sigma ref u largs = let (_, midargs as s) = descend (ref,u) largs in let d, lrest = whd_nothing_for_iota env sigma s in let f = refolding_data, midargs in - let* (c, rest) = reduce_fix allowed_reds env sigma (Some f) (destFix sigma d) lrest in + let* (c, rest) = reduce_fix cache_reds env sigma (Some f) (destFix sigma d) lrest in Reduced ((c, rest), nocase) | NotAnElimination when unfold_nonelim -> - let c = reference_value env sigma ref u in + let c = reference_value cache env sigma ref u in Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase) | _ -> NotReducible in match ans with | NotReducible when unfold_anyway -> - let c = reference_value env sigma ref u in + let c = reference_value cache env sigma ref u in Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase) | _ -> ans @@ -845,11 +856,11 @@ and whd_construct_stack allowed_reds env sigma s = (* reduce until finding an applied constructor (or primitive value) or fail *) -and whd_construct allowed_reds env sigma s = +and whd_construct ((cache,_),_ as allowed_reds) env sigma s = let (constr, cargs) = whd_simpl_stack allowed_reds env sigma s in match match_eval_ref env sigma constr cargs with | Some (ref, u) -> - (match reference_opt_value env sigma ref u with + (match reference_opt_value cache env sigma ref u with | None -> NotReducible | Some gvalue -> if reducible_construct sigma gvalue then Reduced (Some (ref, u), gvalue, cargs) @@ -868,6 +879,7 @@ and whd_construct allowed_reds env sigma s = let try_red_product env sigma c = let simpfun c = clos_norm_flags RedFlags.betaiotazeta env sigma c in + let cache = CacheTable.create 12 in let rec redrec env x = let x = whd_betaiota env sigma x in match EConstr.kind sigma x with @@ -907,7 +919,7 @@ let try_red_product env sigma c = | Some (ref, u) -> (* TO DO: re-fold fixpoints after expansion *) (* to get true one-step reductions *) - (match reference_opt_value env sigma ref u with + (match reference_opt_value cache env sigma ref u with | None -> NotReducible | Some c -> Reduced c) | _ -> NotReducible) @@ -1299,7 +1311,7 @@ let find_hnf_rectype env sigma t = exception NotStepReducible let one_step_reduce env sigma c = - let cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in + let (cache,_), _ as cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in let rec redrec (x, stack) = match EConstr.kind sigma x with | Lambda (n,t,c) -> @@ -1324,7 +1336,7 @@ let one_step_reduce env sigma c = begin match red_elim_const cache_reds env sigma ref u stack with | Reduced (c, _) -> c | NotReducible -> - match reference_opt_value env sigma ref u with + match reference_opt_value cache env sigma ref u with | Some d -> (d, stack) | None -> raise NotStepReducible end