1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
open Lang

module Make (M : sig
  type type_expression

  val fprintf_type : Format.formatter -> type_expression -> unit
end) =
struct
  (** [fprintf_literal fmt lit] writes the textual form of the literal [lit]
      to [fmt]: ["Unit"] for [Unit], ["True"] or ["False"] for [Bool], and the
      carried string unchanged for [Custom]. *)
  let fprintf_literal fmt lit =
    let text =
      match lit with
      | Unit -> "Unit"
      | Bool true -> "True"
      | Bool false -> "False"
      | Custom s -> s
    in
    Format.pp_print_string fmt text

  (** [fprintf_const fmt c] prints the constant [c] on [fmt]: a [Literal] is
      rendered by [fprintf_literal], a [Var] as its bare name. *)
  let fprintf_const fmt c =
    match c with
    | Var name -> Format.pp_print_string fmt name
    | Literal lit -> fprintf_literal fmt lit

  (** [fprintf_pattern fmt p] writes the pattern [p], which is a plain string,
      verbatim to [fmt]. *)
  let fprintf_pattern fmt p = Format.pp_print_string fmt p

  (** [get_args_nosugar e] collects the parameter names of the leading chain
      of [Abstract] nodes in [e] whose binder tag is [Raw]. Collection stops
      at the first [Generated] abstraction or at the first node that is not an
      abstraction. Names are returned innermost first, i.e. in the reverse of
      the order in which they are encountered. *)
  let get_args_nosugar expr =
    let rec collect params = function
      | Abstract (Raw, x, body) -> collect (x :: params) body
      | Abstract (Generated, _, _) -> params
      | _ -> params
    in
    collect [] expr

  (** [get_args_sugar acc e] peels the leading [Generated] abstractions off
      [e], consing each parameter onto [acc] (so the innermost parameter ends
      up first). It returns the accumulated parameters together with the first
      expression that is not a [Generated] abstraction; a [Raw] abstraction is
      returned untouched. *)
  let rec get_args_sugar acc expr =
    match expr with
    | Abstract (Generated, x, body) -> get_args_sugar (x :: acc) body
    | Abstract (Raw, _, _) -> (acc, expr)
    | _ -> (acc, expr)

  (** [fprintf_args fmt l] prints the patterns of [l] on [fmt], separated by
      single spaces. [l] is reversed before printing, so it is expected in the
      innermost-first order in which [get_args_sugar] accumulates
      parameters. An empty list prints nothing. *)
  let fprintf_args fmt l =
    (* The separator writes to the formatter handed to it by
       [Format.pp_print_list] instead of a formatter captured from the
       enclosing scope. *)
    let pp_space ppf () = Format.pp_print_string ppf " " in
    Format.pp_print_list ~pp_sep:pp_space fprintf_pattern fmt (List.rev l)

  (** [fprintf_case_string fmt s] starts a new line on [fmt] and prints the
      constructor case [s] prefixed with ["| "]. The newline is emitted with
      [Format.pp_print_newline], which also flushes the formatter. *)
  let fprintf_case_string fmt s =
    Format.pp_print_newline fmt ();
    Format.pp_print_string fmt "| ";
    Format.pp_print_string fmt s

  (** [fprintf_type_dec fmt l] prints every constructor name of [l] on [fmt]
      using [fprintf_case_string], with nothing printed between two
      consecutive cases. An empty list prints nothing. *)
  let fprintf_type_dec fmt l =
    (* Each case already begins with its own newline, so the separator is a
       no-op; it no longer goes through a captured formatter. *)
    let pp_nothing _ppf () = () in
    Format.pp_print_list ~pp_sep:pp_nothing fprintf_case_string fmt l

  (** [fprintf_bind fmt (f, e1, e2)] prints the binding of [f] to [e1] with
      body [e2] in the form [let f = e1 in e2]. When [e1] starts with
      [Generated] abstractions (as reported by [get_args_sugar]), their
      parameters are printed right after [f] and only the remaining body is
      printed after the [=] sign. *)
  let rec fprintf_bind fmt (f, e1, e2) =
    match get_args_sugar [] e1 with
    | [], _ ->
        Format.fprintf fmt "let %a = %a in %a" fprintf_pattern f fprintf_expr e1
          fprintf_expr e2
    | args, body ->
        Format.fprintf fmt "let %a %a = %a in %a" fprintf_pattern f fprintf_args
          args fprintf_expr body fprintf_expr e2

  (* [fprintf_expr fmt e] pretty-prints the expression [e] on [fmt].

     A [Match] tagged [Generated] is printed as a conditional
     ([if c then a else b end]); its two [Bool] cases are accepted in either
     order. Any other [Generated] match raises [Failure "internal error"].
     All other matches are printed as [match ... end], one case per line.

     The [@.] directives used below emit a newline and also flush the
     formatter. *)
  and fprintf_expr fmt = function
    | Const c ->
        fprintf_const fmt c
    | Bind (p, e1, e2) ->
        fprintf_bind fmt (p, e1, e2)
    | Abstract (_, p, e) ->
        Format.fprintf fmt "(fun %a -> %a)" fprintf_pattern p fprintf_expr e
    (* Only an argument that is itself an application is parenthesised.
       NOTE(review): other compound arguments (e.g. [Bind]) are printed bare;
       confirm that this is unambiguous for the consumer of the output. *)
    | Apply (e, (Apply _ as arg)) ->
        Format.fprintf fmt "%a (%a)" fprintf_expr e fprintf_expr arg
    | Apply (e, arg) ->
        Format.fprintf fmt "%a %a" fprintf_expr e fprintf_expr arg
    | Match
        (Generated, cond, [(Bool false, else_branch); (Bool true, then_branch)])
    | Match
        (Generated, cond, [(Bool true, then_branch); (Bool false, else_branch)])
      ->
        Format.fprintf fmt "if %a then %a else %a end" fprintf_expr cond
          fprintf_expr then_branch fprintf_expr else_branch
    | Match (Generated, _, _) ->
        (* A generated match that is not a two-way boolean test has no
           conditional form to print. *)
        failwith "internal error"
    | Match (_, scrutinee, cases) ->
        (* The separator is given its own formatter by [pp_print_list]. *)
        Format.fprintf fmt "match %a@.%a@.end@." fprintf_expr scrutinee
          (Format.pp_print_list ~pp_sep:Format.pp_print_newline
             fprintf_match_case)
          cases
    | Type (id, cons, e) ->
        Format.fprintf fmt "let type %s =%a@.in@.%a" id fprintf_type_dec cons
          fprintf_expr e

  (* [fprintf_match_case fmt (con, expr)] prints one case of a match as
     [| con -> expr], where [con] is a literal. *)
  and fprintf_match_case fmt (con, expr) =
    Format.fprintf fmt "| %a -> %a" fprintf_literal con fprintf_expr expr

  (** [fprintf_primitive_type fmt ty] prints the name of the primitive type
      [ty] on [fmt]: ["unit"], ["bool"], or the identifier carried by a
      [Custom] type (its second component is ignored). Constructors are
      resolved with [Types] opened locally. *)
  let fprintf_primitive_type fmt ty =
    let name =
      Types.(
        match ty with
        | Unit -> "unit"
        | Bool -> "bool"
        | Custom (id, _cons) -> id)
    in
    Format.pp_print_string fmt name

  (** Printer for type expressions; this is the [fprintf_type] function
      supplied by the functor argument [M], re-exported unchanged. *)
  let fprintf_type = M.fprintf_type

  (** [fprintf_et fmt (e, t)] prints the expression [e], then [" : "], then
      the type [t], i.e. a type-annotated expression of the form [e : t]. *)
  let fprintf_et fmt (e, t) =
    fprintf_expr fmt e;
    Format.pp_print_string fmt " : ";
    fprintf_type fmt t
end