Skip to content

Commit

Permalink
tacred: also cache result of reference_opt_value
Browse files Browse the repository at this point in the history
  • Loading branch information
SkySkimmer authored and eladrion committed Jan 15, 2024
1 parent 5b3692d commit 880f3da
Showing 1 changed file with 54 additions and 42 deletions.
96 changes: 54 additions & 42 deletions pretyping/tacred.ml
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,31 @@ let destEvalRefU sigma c = match EConstr.kind sigma c with
| Evar ev -> (EvalEvar ev, EInstance.empty)
| _ -> anomaly (Pp.str "Not an unfoldable reference.")

let reference_opt_value env sigma eval u =
module CacheTable = Hashtbl.Make(struct
type t = Constant.t * UVars.Instance.t

(* WARNING if we use CanOrd and we have [M.x := N.x] unfolding M.x
first will put [M.x -> N.x] in the cache, then trying to unfold
N.x will return N.x ie loop. *)
let equal (c,u) (c',u') =
Constant.UserOrd.equal c c' && UVars.Instance.equal u u'

let hash (c,u) =
Hashset.Combine.combine (Constant.UserOrd.hash c) (UVars.Instance.hash u)
end)

let reference_opt_value cache env sigma eval u =
match eval with
| EvalConst cst ->
let u = EInstance.kind sigma u in
Option.map EConstr.of_constr (constant_opt_value_in env (cst,u))
let cu = (cst, u) in
begin match CacheTable.find_opt cache cu with
| Some v -> v
| None ->
let v = Option.map EConstr.of_constr (constant_opt_value_in env cu) in
CacheTable.add cache cu v;
v
end
| EvalVar id ->
env |> lookup_named id |> NamedDecl.get_value
| EvalRel n ->
Expand All @@ -143,8 +163,8 @@ let reference_opt_value env sigma eval u =
| c -> Some (EConstr.of_kind c)

exception NotEvaluable
let reference_value env sigma c u =
match reference_opt_value env sigma c u with
let reference_value cache env sigma c u =
match reference_opt_value cache env sigma c u with
| None -> raise NotEvaluable
| Some d -> d

Expand Down Expand Up @@ -252,8 +272,8 @@ let check_fix_reversibility env sigma ref u labs args minarg refs ((lv,i),_ as f
refolding_data;
}

let compute_fix_wrapper allowed_reds env sigma ref u =
try match reference_opt_value env sigma ref u with
let compute_fix_wrapper ((cache,_),allowed_reds) env sigma ref u =
try match reference_opt_value cache env sigma ref u with
| None -> None
| Some c ->
let labs, ccl = whd_decompose_lambda env sigma c in
Expand Down Expand Up @@ -304,7 +324,7 @@ let deactivate_delta allowed_reds =
for unary fixpoints and to the last constant encapsulating the Fix
for mutual fixpoints *)

let compute_consteval allowed_reds env sigma ref u =
let compute_consteval ((cache,_),allowed_reds as cache_reds) env sigma ref u =
let allowed_reds_no_delta = deactivate_delta allowed_reds in
let rec srec env all_abs lastref lastu onlyproj c =
let c', args = whd_stack_gen allowed_reds_no_delta env sigma c in
Expand All @@ -325,7 +345,7 @@ let compute_consteval allowed_reds env sigma ref u =
with Elimconst -> NotAnElimination
else
(* Try to refold to [lastref] *)
let last_labs, last_args, names = invert_names allowed_reds env sigma lastref lastu names i in
let last_labs, last_args, names = invert_names cache_reds env sigma lastref lastu names i in
try check_fix_reversibility env sigma lastref lastu last_labs last_args n_all_abs names fix
with Elimconst -> NotAnElimination)
| Case (_,_,_,_,_,d,_) when isRel sigma d && not onlyproj -> EliminationCases (List.length all_abs)
Expand All @@ -334,38 +354,29 @@ let compute_consteval allowed_reds env sigma ref u =
| _ when isTransparentEvalRef env sigma (RedFlags.red_transparent allowed_reds) c' ->
(* Continue stepwise unfolding from [c' args] *)
let ref, u = destEvalRefU sigma c' in
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotAnElimination (* e.g. if a rel *)
| Some c -> srec env all_abs ref u onlyproj (applist (c, args)))
| _ -> NotAnElimination
in
match reference_opt_value env sigma ref u with
match reference_opt_value cache env sigma ref u with
| None -> NotAnElimination
| Some c -> srec env [] ref u false c

module CacheTable = Hashtbl.Make(struct
type t = Constant.t * UVars.Instance.t
let equal (c,u) (c',u') =
Constant.CanOrd.equal c c' && UVars.Instance.equal u u'

let hash (c,u) =
Hashset.Combine.combine (Constant.CanOrd.hash c) (UVars.Instance.hash u)
end)

let make_simpl_cache () =
CacheTable.create 12
CacheTable.create 12, CacheTable.create 12

let reference_eval (cache,allowed_reds) env sigma ref u =
let reference_eval ((_,cache),_ as cache_reds) env sigma ref u =
match ref with
| EvalConst cst as ref ->
let cu = cst, EInstance.kind sigma u in
(match CacheTable.find_opt cache cu with
| Some v -> v
| None ->
let v = compute_consteval allowed_reds env sigma ref u in
let v = compute_consteval cache_reds env sigma ref u in
CacheTable.add cache cu v;
v)
| ref -> compute_consteval allowed_reds env sigma ref u
| ref -> compute_consteval cache_reds env sigma ref u

(* If f is bound to EliminationFix (n',refs,infos), then n' is the minimal
number of args for triggering the reduction and infos is
Expand Down Expand Up @@ -628,7 +639,7 @@ let make_simpl_reds env =
constants by keeping the name of the constants in the recursive calls;
it fails if no redex is around *)

let rec red_elim_const allowed_reds env sigma ref u largs =
let rec red_elim_const ((cache,_),_ as cache_reds) env sigma ref u largs =
let open ReductionBehaviour in
let nargs = List.length largs in
let* largs, unfold_anyway, unfold_nonelim, nocase =
Expand All @@ -640,47 +651,47 @@ let rec red_elim_const allowed_reds env sigma ref u largs =
| Some (UnfoldWhen { recargs = x::l } | UnfoldWhenNoMatch { recargs = x::l })
when nargs <= List.fold_left max x l -> NotReducible
| Some (UnfoldWhen { recargs; nargs = None }) ->
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
false,
false,
false)
| Some (UnfoldWhenNoMatch { recargs; nargs = None }) ->
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
false,
false,
true)
| Some (UnfoldWhen { recargs; nargs = Some n }) ->
let is_empty = List.is_empty recargs in
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
is_empty && nargs >= n,
not is_empty && nargs >= n,
false)
| Some (UnfoldWhenNoMatch { recargs; nargs = Some n }) ->
let is_empty = List.is_empty recargs in
let* params = reduce_params allowed_reds env sigma largs recargs in
let* params = reduce_params cache_reds env sigma largs recargs in
Reduced (params,
is_empty && nargs >= n,
not is_empty && nargs >= n,
true)
in
let ans = match reference_eval allowed_reds env sigma ref u with
let ans = match reference_eval cache_reds env sigma ref u with
| EliminationCases n when nargs >= n ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
let c', lrest = whd_nothing_for_iota env sigma (c, largs) in
let* ans = special_red_case allowed_reds env sigma (EConstr.destCase sigma c') in
let* ans = special_red_case cache_reds env sigma (EConstr.destCase sigma c') in
Reduced ((ans, lrest), nocase)
| EliminationProj n when nargs >= n ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
let c', lrest = whd_nothing_for_iota env sigma (c, largs) in
let* ans = reduce_proj allowed_reds env sigma c' in
let* ans = reduce_proj cache_reds env sigma c' in
Reduced ((ans, lrest), nocase)
| EliminationFix {trigger_min_args; refolding_target; refolding_data}
when nargs >= trigger_min_args ->
let rec descend (ref,u) args =
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
if evaluable_reference_eq sigma ref refolding_target then
(c,args)
else
Expand All @@ -689,16 +700,16 @@ let rec red_elim_const allowed_reds env sigma ref u largs =
let (_, midargs as s) = descend (ref,u) largs in
let d, lrest = whd_nothing_for_iota env sigma s in
let f = refolding_data, midargs in
let* (c, rest) = reduce_fix allowed_reds env sigma (Some f) (destFix sigma d) lrest in
let* (c, rest) = reduce_fix cache_reds env sigma (Some f) (destFix sigma d) lrest in
Reduced ((c, rest), nocase)
| NotAnElimination when unfold_nonelim ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase)
| _ -> NotReducible
in
match ans with
| NotReducible when unfold_anyway ->
let c = reference_value env sigma ref u in
let c = reference_value cache env sigma ref u in
Reduced ((whd_betaiotazeta env sigma (applist (c, largs)), []), nocase)
| _ -> ans

Expand Down Expand Up @@ -845,11 +856,11 @@ and whd_construct_stack allowed_reds env sigma s =

(* reduce until finding an applied constructor (or primitive value) or fail *)

and whd_construct allowed_reds env sigma s =
and whd_construct ((cache,_),_ as allowed_reds) env sigma s =
let (constr, cargs) = whd_simpl_stack allowed_reds env sigma s in
match match_eval_ref env sigma constr cargs with
| Some (ref, u) ->
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotReducible
| Some gvalue ->
if reducible_construct sigma gvalue then Reduced (Some (ref, u), gvalue, cargs)
Expand All @@ -868,6 +879,7 @@ and whd_construct allowed_reds env sigma s =

let try_red_product env sigma c =
let simpfun c = clos_norm_flags RedFlags.betaiotazeta env sigma c in
let cache = CacheTable.create 12 in
let rec redrec env x =
let x = whd_betaiota env sigma x in
match EConstr.kind sigma x with
Expand Down Expand Up @@ -907,7 +919,7 @@ let try_red_product env sigma c =
| Some (ref, u) ->
(* TO DO: re-fold fixpoints after expansion *)
(* to get true one-step reductions *)
(match reference_opt_value env sigma ref u with
(match reference_opt_value cache env sigma ref u with
| None -> NotReducible
| Some c -> Reduced c)
| _ -> NotReducible)
Expand Down Expand Up @@ -1299,7 +1311,7 @@ let find_hnf_rectype env sigma t =
exception NotStepReducible

let one_step_reduce env sigma c =
let cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in
let (cache,_), _ as cache_reds = make_simpl_cache(), RedFlags.betadeltazeta in
let rec redrec (x, stack) =
match EConstr.kind sigma x with
| Lambda (n,t,c) ->
Expand All @@ -1324,7 +1336,7 @@ let one_step_reduce env sigma c =
begin match red_elim_const cache_reds env sigma ref u stack with
| Reduced (c, _) -> c
| NotReducible ->
match reference_opt_value env sigma ref u with
match reference_opt_value cache env sigma ref u with
| Some d -> (d, stack)
| None -> raise NotStepReducible
end
Expand Down

0 comments on commit 880f3da

Please sign in to comment.