Skip to content

Commit

Permalink
refactor(pkg): use stdune in sat solver (#11282)
Browse files Browse the repository at this point in the history
Signed-off-by: Rudi Grinberg <[email protected]>
  • Loading branch information
rgrinberg authored Jan 12, 2025
1 parent 098117d commit 53c6dfc
Showing 1 changed file with 45 additions and 35 deletions.
80 changes: 45 additions & 35 deletions src/0install-solver/sat.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
- We add an at_most_one_clause (the paper suggests this in the Excercises, and
it's very useful for our purposes). *)

open Stdune

let debug = false

module type USER = sig
Expand All @@ -28,14 +30,16 @@ module VarID : sig
type t
type mint

val compare : t -> t -> int
val to_dyn : t -> Dyn.t
val compare : t -> t -> Ordering.t
val make_mint : unit -> mint
val issue : mint -> t
end = struct
type t = int
type mint = int ref

let compare (a : int) (b : int) = compare a b
let to_dyn = Dyn.int
let compare (a : int) (b : int) = Int.compare a b
let make_mint () = ref 0

let issue mint =
Expand Down Expand Up @@ -123,17 +127,24 @@ module Make (User : USER) = struct

let lit_equal (s1, v1) (s2, v2) = s1 == s2 && v1 == v2

module VarSet = Set.Make (VarID)
module C = Comparable.Make (VarID)
module VarSet = C.Set

module LitSet = struct
module C = Comparable.Make (struct
type t = lit

let to_dyn = Dyn.opaque

module LitSet = Set.Make (struct
type t = lit
let compare (s1, v1) (s2, v2) =
match VarID.compare v1.id v2.id with
| Eq -> Poly.compare s1 s2
| x -> x
;;
end)

let compare (s1, v1) (s2, v2) =
match VarID.compare v1.id v2.id with
| 0 -> compare s1 s2
| x -> x
;;
end)
include C.Set
end

type solution = lit -> bool

Expand Down Expand Up @@ -183,9 +194,9 @@ module Make (User : USER) = struct
let seen = ref LitSet.empty in
let rec find_unique = function
| [] -> []
| x :: xs when LitSet.mem x !seen -> find_unique xs
| x :: xs when LitSet.mem !seen x -> find_unique xs
| x :: xs ->
seen := LitSet.add x !seen;
seen := LitSet.add !seen x;
x :: find_unique xs
in
find_unique lits
Expand Down Expand Up @@ -281,7 +292,7 @@ module Make (User : USER) = struct
var_info.level <- get_decision_level problem;
var_info.reason <- Some reason;
problem.trail <- lit :: problem.trail;
Queue.add lit problem.propQ;
Queue.push problem.propQ lit;
true
;;

Expand Down Expand Up @@ -330,15 +341,15 @@ module Make (User : USER) = struct
(* if debug then log_debug "propagate: queue length = %d" (Queue.length problem.propQ); *)
try
while not (Queue.is_empty problem.propQ) do
let lit = Queue.take problem.propQ in
let lit = Queue.pop_exn problem.propQ in
let old_watches = Queue.create () in
let watches = watch_queue lit in
Queue.transfer watches old_watches;
(* if debug then log_debug "%s -> True : watches: %d" (name_lit lit) (Queue.length old_watches); *)

(* Notifiy all watchers *)
while not (Queue.is_empty old_watches) do
let clause = Queue.take old_watches in
let clause = Queue.pop_exn old_watches in
if not (clause#propagate lit)
then (
(* Conflict *)
Expand All @@ -361,7 +372,7 @@ module Make (User : USER) = struct
(* Call [clause#propagate lit] when lit becomes True *)
let watch_lit lit clause =
(* if debug then log_debug "%s is watching for %s to become True" clause#to_string (name_lit lit); *)
Queue.add clause (watch_queue lit)
Queue.push (watch_queue lit) clause
;;

let union_clause problem lits =
Expand Down Expand Up @@ -418,7 +429,7 @@ module Make (User : USER) = struct

(* We can only cause a conflict if all our lits are False, so they're all the cause.
e.g. if we are "A or B or not(C)" then "not(A) and not(B) and C" causes a conflict. *)
method calc_reason = List.map neg (Array.to_list lits)
method calc_reason = List.map ~f:neg (Array.to_list lits)

(** Which literals caused [lit] to have its current value? *)
method calc_reason_for lit =
Expand Down Expand Up @@ -481,8 +492,7 @@ module Make (User : USER) = struct
var_info.undo <- undo :: var_info.undo;
try
(* We set all other literals to False. *)
lits
|> List.iter (fun l ->
List.iter lits ~f:(fun l ->
match lit_value l with
| True when not (lit_equal l lit) ->
(* Due to queuing, we might get called with current = None
Expand Down Expand Up @@ -520,11 +530,11 @@ module Make (User : USER) = struct
(** Which literals caused [lit] to have its current value? *)
method calc_reason_for lit =
(* Find the True literal. Any true literal other than [lit] would do. *)
[ List.find (fun l -> (not (lit_equal l lit)) && lit_value l = True) lits ]
[ List.find_exn lits ~f:(fun l -> (not (lit_equal l lit)) && lit_value l = True) ]

method best_undecided =
(* if debug then log_debug "best_undecided: %s" (string_of_lits lits); *)
List.find_opt (fun l -> lit_value l = Undecided) lits
List.find lits ~f:(fun l -> lit_value l = Undecided)

method get_selected = !current
method pp = Pp.text "<at most one: " ++ pp_lits lits ++ Pp.char '>'
Expand Down Expand Up @@ -574,19 +584,19 @@ module Make (User : USER) = struct
if List.length lits = 0
then problem.toplevel_conflict <- true
else if (* if debug then log_debug "at_least_one(%s)" (string_of_lits lits); *)
List.exists (fun l -> lit_value l = True) lits
List.exists lits ~f:(fun l -> lit_value l = True)
then (* Trivially true already if any literal is True. *)
()
else (
let seen = ref LitSet.empty in
let rec simplify unique = function
| [] -> Some unique
| x :: _ when LitSet.mem (neg x) !seen -> None (* X or not(X) is always True *)
| x :: xs when LitSet.mem x !seen -> simplify unique xs (* Skip duplicates *)
| x :: _ when LitSet.mem !seen (neg x) -> None (* X or not(X) is always True *)
| x :: xs when LitSet.mem !seen x -> simplify unique xs (* Skip duplicates *)
| x :: xs when lit_value x = False ->
simplify unique xs (* Skip values known to be False *)
| x :: xs ->
seen := LitSet.add x !seen;
seen := LitSet.add !seen x;
simplify (x :: unique) xs
in
(* At this point, [unique] contains only [Undefined] literals. *)
Expand Down Expand Up @@ -623,9 +633,9 @@ module Make (User : USER) = struct
(* Ignore any literals already known to be False.
If any are True then they're enqueued and we'll process them
soon. *)
let lits = List.filter (fun l -> lit_value l <> False) lits in
let lits = List.filter lits ~f:(fun l -> lit_value l <> False) in
let clause = new at_most_one_clause problem lits in
List.iter (fun l -> watch_lit l (clause :> clause)) lits;
List.iter lits ~f:(fun l -> watch_lit l (clause :> clause));
clause
;;

Expand Down Expand Up @@ -699,12 +709,11 @@ module Make (User : USER) = struct
- if the variable was assigned at the current level,
mark it for expansion
- otherwise, add it to learnt *)
p_reason
|> List.iter (fun lit ->
List.iter p_reason ~f:(fun lit ->
let var = var_of_lit lit in
if not (VarSet.mem var.id !seen)
if not (VarSet.mem !seen var.id)
then (
seen := VarSet.add var.id !seen;
seen := VarSet.add !seen var.id;
let var_info = var_of_lit lit in
if var_info.level = get_decision_level problem
then
Expand Down Expand Up @@ -744,7 +753,7 @@ module Make (User : USER) = struct
let var = var_of_lit lit in
let reason = var.reason in
undo_one problem;
if not (VarSet.mem var.id !seen)
if not (VarSet.mem !seen var.id)
then
(* if debug then log_debug "(irrelevant: %s)" (name_lit lit); *)
next_interesting ()
Expand Down Expand Up @@ -820,8 +829,9 @@ module Make (User : USER) = struct
If it leads to a conflict, we'll backtrack and
try it the other way. *)
let undecided =
try List.find (fun info -> info.value = Undecided) problem.vars with
| Not_found ->
match List.find problem.vars ~f:(fun info -> info.value = Undecided) with
| Some s -> s
| None ->
(* Everything is assigned without conflicts *)
(* if debug then log_debug "SUCCESS!"; *)
raise
Expand Down

0 comments on commit 53c6dfc

Please sign in to comment.