1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
open Hc

(** Variables are identified by integers. [node] requires a node's variable
    to be strictly greater than its children's, so the root of a diagram
    carries its largest variable. *)
type var = int

(** A BDD node wrapped by the [Hc] hash-consing library. Nodes are compared
    with physical equality ([==]) throughout this file, which presumably
    relies on [Hc] sharing structurally equal nodes — see [Hc]. *)
type hidden = view hash_consed

(** The shape of one node: a terminal, or [Node (v, low, high)], where [low]
    is followed when [v] is false and [high] when [v] is true. *)
and view = True | False | Node of var * hidden * hidden

(** Hashing and equality on node views, as required by [Hc.Make]. *)
module HashedT = struct
  type t = view

  (* Children are themselves hash-consed nodes, so they are compared by
     address; only the variable needs a value comparison. *)
  let equal a b =
    match a, b with
    | True, True -> true
    | False, False -> true
    | Node (va, lo_a, hi_a), Node (vb, lo_b, hi_b) ->
        va = vb && lo_a == lo_b && hi_a == hi_b
    | (True | False | Node _), _ -> false

  (* Terminals hash to 0 and 1; an internal node mixes its variable with the
     unique tags of its two children, offset by 2. *)
  let hash = function
    | False -> 0
    | True -> 1
    | Node (v, lo, hi) -> (19 * ((19 * v) + lo.tag)) + hi.tag + 2
end

(* Hash-consing table for BDD views, built from [HashedT]. *)
module Hbdd = Hc.Make (HashedT)

(** Identity-based keys for hash-consed nodes: two keys are equal only when
    they are the same physical node, and each hashes to its [Hc] tag. *)
module Hash = struct
  type t = hidden

  let equal a b = a == b

  let hash { tag; _ } = tag
end

(* Hash-conses a view through [Hbdd]. Presumably this returns the shared node
   for structurally equal views — confirm against [Hc]. Prefer [node], which
   also enforces variable ordering and reduction. *)
let hc = Hbdd.hashcons

(* Exposes the top-level shape of a node for pattern matching. *)
let view x = x.node

(* Memoization keyed on BDD nodes via [Hash] (address equality, tag hash).
   NOTE(review): "Weak" presumably means the table does not keep nodes
   alive — confirm in [Memo]. *)
module Mem = Memo.MakeWeak (Hash)

(* The two terminal diagrams. *)
let true_bdd = hc True

let false_bdd = hc False

(** [get_order bdd] is the variable at the top of [bdd], or [-1] for a
    terminal, so that every variable [>= 0] ranks above the terminals. *)
let get_order bdd =
  match view bdd with
  | Node (v, _, _) -> v
  | True | False -> -1

(** [node v l h] is the diagram "if [v] then [h] else [l]". If both branches
    are the same node, that branch is returned instead of a redundant test.
    @raise Invalid_argument if [v] is not strictly greater than the top
    variable of [l] and of [h]. *)
let node v l h =
  if get_order l >= v || get_order h >= v then invalid_arg "node" ;
  match Hash.equal l h with
  | true -> l
  | false -> hc (Node (v, l, h))

(* The diagram of the single variable [v]: false on the low branch, true on
   the high branch. Raises [Invalid_argument] (from [node]) if [v < 0]. *)
let var_bdd v = node v false_bdd true_bdd

(** [fprintf fmt bdd] prints [bdd] as nested conditionals of the form
    [v ? (high) : (low)], with terminals shown as [true] / [false]. *)
let rec fprintf fmt bdd =
  match view bdd with
  | Node (v, lo, hi) ->
      Format.fprintf fmt "%d ? (%a) : (%a)" v fprintf hi fprintf lo
  | True -> Format.pp_print_string fmt "true"
  | False -> Format.pp_print_string fmt "false"

(** [to_string bdd] renders [bdd] with [fprintf] into a string. *)
let to_string bdd = Format.asprintf "%a" fprintf bdd

(** [of_bool b] is the constant diagram for [b]. *)
let of_bool b = if b then true_bdd else false_bdd

(** [neg x] is the complement of [x]: the terminals are swapped and every
    internal test is kept. Recursion goes through [Mem.memo], which is applied
    on each call of [neg] (presumably a fresh memo table per call — confirm in
    [Memo]). *)
let neg x =
  let flip self bdd =
    match view bdd with
    | Node (v, lo, hi) -> node v (self lo) (self hi)
    | True -> false_bdd
    | False -> true_bdd
  in
  Mem.memo flip x

(* TODO: memoize on pairs of nodes ("memo2"); without it, sub-diagrams shared
   inside [n1] or [n2] are recombined each time they are reached. *)
(** [comb_comm op n1 n2] builds the diagram of [op] applied pointwise to [n1]
    and [n2]. [op] must be commutative: the mixed terminal case evaluates only
    [op true false], and terminal/node pairs are handled symmetrically.

    [node] requires a parent's variable to be strictly greater than its
    children's. So when the top variables differ, the operand with the
    GREATER variable is split first and the other operand is passed down
    unchanged. *)
let rec comb_comm op n1 n2 =
  let comb_comm = comb_comm op in
  match (view n1, view n2) with
  | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
      node v1 (comb_comm l1 l2) (comb_comm h1 h2)
  | Node (v1, l1, h1), Node (v2, _, _) when v1 > v2 ->
      (* [n1] has the higher top variable, so it is tested first. *)
      node v1 (comb_comm l1 n2) (comb_comm h1 n2)
  | Node (_, _, _), Node (v2, l2, h2) ->
      (* Here v1 < v2: [n2] has the higher top variable. *)
      node v2 (comb_comm n1 l2) (comb_comm n1 h2)
  | True, Node (v, l, h) | Node (v, l, h), True ->
      node v (comb_comm l true_bdd) (comb_comm h true_bdd)
  | False, Node (v, l, h) | Node (v, l, h), False ->
      node v (comb_comm l false_bdd) (comb_comm h false_bdd)
  | False, False ->
      of_bool (op false false)
  | False, True | True, False ->
      of_bool (op true false)
  | True, True ->
      of_bool (op true true)

(* TODO: memoize on pairs of nodes ("memo2"); without it, sub-diagrams shared
   inside [n1] or [n2] are recombined each time they are reached. *)
(** [comb op n1 n2] builds the diagram of [op] applied pointwise to [n1] and
    [n2]. Unlike [comb_comm], the operands keep their order, so [op] need not
    be commutative.

    [node] requires a parent's variable to be strictly greater than its
    children's. So when the top variables differ, the operand with the
    GREATER variable is split first and the other operand is passed down
    unchanged. *)
let rec comb op n1 n2 =
  let comb = comb op in
  match (view n1, view n2) with
  | Node (v1, l1, h1), Node (v2, l2, h2) when v1 = v2 ->
      node v1 (comb l1 l2) (comb h1 h2)
  | Node (v1, l1, h1), Node (v2, _, _) when v1 > v2 ->
      (* [n1] has the higher top variable, so it is tested first. *)
      node v1 (comb l1 n2) (comb h1 n2)
  | Node (_, _, _), Node (v2, l2, h2) ->
      (* Here v1 < v2: [n2] has the higher top variable. *)
      node v2 (comb n1 l2) (comb n1 h2)
  | True, Node (v, l, h) ->
      node v (comb true_bdd l) (comb true_bdd h)
  | Node (v, l, h), True ->
      node v (comb l true_bdd) (comb h true_bdd)
  | False, Node (v, l, h) ->
      node v (comb false_bdd l) (comb false_bdd h)
  | Node (v, l, h), False ->
      node v (comb l false_bdd) (comb h false_bdd)
  | False, False ->
      of_bool (op false false)
  | False, True ->
      of_bool (op false true)
  | True, False ->
      of_bool (op true false)
  | True, True ->
      of_bool (op true true)

(** Conjunction of two diagrams. *)
let conj = comb_comm ( && )

(** Disjunction of two diagrams. *)
let disj = comb_comm ( || )

(** Implication [n1 => n2]; order matters, hence [comb]. *)
let imp = comb (fun a b -> b || not a)

(** Equivalence [n1 <=> n2]. *)
let eq = comb_comm ( = )

(** [compute tbl bdd] restricts [bdd] by the partial assignment [tbl], a
    [Hashtbl] from variables to booleans. At every bound variable it follows
    the high branch for [true] and the low branch for [false]. Unbound
    variables are kept and rebuilt with [node]. *)
let compute tbl =
  let lookup v = try Some (Hashtbl.find tbl v) with Not_found -> None in
  let rec restrict bdd =
    match view bdd with
    | True -> true_bdd
    | False -> false_bdd
    | Node (v, lo, hi) -> (
        match lookup v with
        | Some true -> restrict hi
        | Some false -> restrict lo
        | None -> node v (restrict lo) (restrict hi) )
  in
  restrict

(** [size bdd] is the number of distinct internal nodes reachable from [bdd].
    Terminals are not counted, and a node shared by several paths is counted
    once. *)
let size =
  (* [Hash] compares nodes by address and hashes them by their unique tag,
     which is cheaper than polymorphic [Hashtbl.hash] on the whole record. *)
  let module H = Hashtbl.Make (Hash) in
  fun bdd ->
    (* A fresh visited set on every call. A single table shared across calls
       made later calls skip nodes already seen and under-report. *)
    let visited = H.create 64 in
    let rec count bdd =
      match view bdd with
      | False | True -> 0
      | Node (_, l, h) ->
          if H.mem visited bdd then 0
          else (
            H.add visited bdd () ;
            1 + count l + count h )
    in
    count bdd

(** [is_sat bdd] is [false] only for the [False] terminal. Assuming [bdd] was
    built with [node] (so it is reduced), every other diagram has at least
    one satisfying assignment. *)
let is_sat bdd =
  match view bdd with
  | True | Node _ -> true
  | False -> false

(** [count_sat card bdd] counts the satisfying assignments of [bdd] over the
    variables [0 .. card - 1]. The count uses native [int] shifts with no
    overflow check, so very large counts wrap silently.
    Fails with [Assert_failure] if [bdd] mentions a variable outside that
    range. *)
let count_sat card =
  let models =
    Mem.memo (fun models bdd ->
        match view bdd with
        | True -> 1
        | False -> 0
        | Node (v, lo, hi) ->
            assert (0 <= v && v < card) ;
            (* Each variable strictly between [v] and a child's top variable
               is unconstrained on that edge and doubles the child's count. *)
            let weighted child = models child lsl (v - get_order child - 1) in
            weighted hi + weighted lo)
  in
  fun bdd ->
    (* Variables above the root are likewise unconstrained. *)
    let skipped = card - get_order bdd - 1 in
    models bdd lsl skipped

(** [any_sat bdd] is [Some assignment] for one satisfying path of [bdd]
    (low branches preferred), or [None] if [bdd] is unsatisfiable. Only
    variables tested along that path appear in the assignment. *)
let any_sat =
  let rec search acc bdd =
    match view bdd with
    | False -> None
    | True -> Some acc
    | Node (v, lo, hi) -> (
        match search acc lo with
        | Some found -> Some ((v, false) :: found)
        | None -> search ((v, true) :: acc) hi )
  in
  search []

(** [all_sat bdd] lists one assignment per path from the root of [bdd] to the
    [True] terminal. Each assignment binds only the variables tested on its
    path. Paths are listed with high branches before low ones, i.e. in
    reverse of a low-first traversal. *)
let all_sat bdd =
  (* Assignments of all true-paths, low branch first. *)
  let rec paths bdd =
    match view bdd with
    | False -> []
    | True -> [[]]
    | Node (v, lo, hi) ->
        let bind b path = (v, b) :: path in
        List.map (bind false) (paths lo) @ List.map (bind true) (paths hi)
  in
  List.rev (paths bdd)

(* TODO: should each assignment also bind the variables that do not occur on its path? *)

(** [random_sat _ bdd] is [Some assignment] for a satisfying path of [bdd],
    or [None] if [bdd] is unsatisfiable. The first argument is ignored.
    Despite the name, the branch choice is currently deterministic (see the
    TODO below). Only variables tested along the path are bound: [(v, true)]
    when the high branch is taken and [(v, false)] for the low branch. *)
let random_sat _ =
  (* let _ = count_sat maxn in *)
  let rec aux assign bdd =
    match view bdd with
    | False ->
        None
    | True ->
        Some assign
    | Node (v, l, h) ->
        if is_sat l && is_sat h then
          (* TODO: choose the branch at random (weighted by [count_sat] for a
             uniform draw); for now the high branch is always taken. *)
          let take_high = true in
          (* The high branch [h] means [v] is true and the low branch [l]
             means [v] is false. The old code recorded the opposite. *)
          if take_high then aux ((v, true) :: assign) h
          else aux ((v, false) :: assign) l
        else (
          match aux assign l with
          | None ->
              aux ((v, true) :: assign) h
          | Some assign ->
              Some ((v, false) :: assign) )
  in
  aux []