1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
open Lang

(** Type inference over [Lang] expressions.

    The functor is parameterised by [M], which supplies the representation of
    type expressions, a generative constraint store ([M.Types.Constraint]),
    the exceptions used to report errors, and [get_cons_type], which maps a
    custom constructor name to its primitive type.

    Inference runs in two phases: a traversal that assigns a type expression
    to every sub-expression while recording constraints, followed by a call
    to the solver whose solution is substituted back into the recorded
    types. *)
module Make (M : sig
  module Types : sig
    type primitive_type = Types.primitive_type

    type type_expression =
      | Variable of string
      | Primitive of Types.primitive_type
      | Arrow of type_expression * type_expression

    (* Returns a type expression that the inference below uses as a
       placeholder for a not-yet-known type. *)
    val mk_fresh : unit -> type_expression

    (* Caught around [Constraint.solve] and re-raised as
       [Error.Recursive_type]. *)
    exception Rectype of type_expression

    module Constraint () : sig
      val add_constraint : type_expression -> type_expression -> unit

      (* The result is consumed as an association list from variable names
         to type expressions. *)
      val solve : unit -> (string * type_expression) list
    end
  end

  module Error : sig
    exception Wrong_application of expr * Types.type_expression

    exception Recursive_type of Types.type_expression
  end

  val get_cons_type : string -> Types.primitive_type
end) =
struct
  (** Constraint store private to this instantiation of the functor. *)
  module C = M.Types.Constraint ()

  (** [unify (a, b)] records the constraints needed for [a] and [b] to be the
      same type. Whenever at least one side is a variable, the pair is handed
      to [C.add_constraint] with the variable as first argument (for two
      variables, the left one comes first). Two arrows are unified
      component-wise, inputs first. *)
  let rec unify =
    let open M.Types in
    function
    | (Variable _ as x), (Arrow _ as a)
    | (Arrow _ as a), (Variable _ as x)
    | (Variable _ as x), (Primitive _ as a)
    | (Primitive _ as a), (Variable _ as x)
    | (Variable _ as x), (Variable _ as a) ->
        C.add_constraint x a
    | Arrow (t1, t1'), Arrow (t2, t2') ->
        unify (t1, t2) ;
        unify (t1', t2')
    | Primitive _, Primitive _ | Primitive _, Arrow _ | Arrow _, Primitive _ ->
        (* NOTE(review): pairs that involve no variable are accepted without
           being compared, so a primitive/primitive or primitive/arrow
           mismatch is not reported here. Confirm this is intended. *)
        ()

  (** Table from identifiers to the type currently inferred for them. It is
      filled during the traversal and rewritten with solved types by the
      final [expr]. *)
  let infered = Hashtbl.create 512

  (** [mk_fresh x] obtains a new type expression from [M.Types.mk_fresh],
      records it in [infered] as the type of [x], and returns it. *)
  let mk_fresh x =
    let res = M.Types.mk_fresh () in
    Hashtbl.add infered x res ; res

  (** Primitive type of a literal. Custom constructors are resolved through
      [M.get_cons_type]. *)
  let literal = function
    | Unit ->
        Types.Unit
    | Bool _ ->
        Types.Bool
    | Custom id ->
        M.get_cons_type id

  (** Type of a constant: the primitive type of a literal, or the type
      recorded for a variable. A variable that has no entry in [infered] yet
      is given a fresh type, which is recorded. *)
  let const = function
    | Literal l ->
        M.Types.Primitive (literal l)
    | Var x -> (
      match Hashtbl.find_opt infered x with
      | Some t ->
          t
      | None ->
          mk_fresh x )

  (** Traversal phase: returns the (unsolved) type of an expression and
      records constraints through [unify] along the way.

      @raise M.Error.Wrong_application
        when the head of an application has a primitive type.
      @raise Failure when a [Match] has no cases. *)
  let rec expr = function
    | Const c ->
        const c
    | Bind (id, e1, e2) ->
        let t = expr e1 in
        (* [Hashtbl.add] shadows any earlier entry for [id]; the entry is not
           removed once [e2] has been typed. *)
        Hashtbl.add infered id t ;
        expr e2
    | Abstract (_, id, e) ->
        let t_body = expr e in
        (* The parameter only gets an entry when it is looked up as a
           variable inside the body (or was bound earlier). When there is no
           entry, give it a fresh type instead of letting [Not_found]
           escape. *)
        let t_param =
          match Hashtbl.find_opt infered id with
          | Some t ->
              t
          | None ->
              mk_fresh id
        in
        M.Types.Arrow (t_param, t_body)
    | Apply (e1, e2) -> (
        let t1 = expr e1 in
        let t2 = expr e2 in
        match t1 with
        | M.Types.Arrow (t_in, t_out) ->
            unify (t2, t_in) ;
            t_out
        | M.Types.Variable _ ->
            (* The head is still unknown: constrain it to be an arrow from
               the argument type to a new result type. The result variable
               comes straight from [M.Types.mk_fresh] and is not recorded in
               [infered]. *)
            let t_out = M.Types.mk_fresh () in
            let t_new = M.Types.Arrow (t2, t_out) in
            unify (t_new, t1) ;
            t_out
        | M.Types.Primitive _ ->
            raise (M.Error.Wrong_application (e1, t1)) )
    | Match (_origin, match_expr, cases) -> (
        let t_match_expr = expr match_expr in
        match cases with
        | [] ->
            failwith "internal error"
        | (case1, expr1) :: rest ->
            (* NOTE(review): only the first case pattern is unified with the
               scrutinee; the patterns of the remaining cases are not
               inspected. Confirm this is intended. *)
            let tcase1 = const (Literal case1) in
            unify (t_match_expr, tcase1) ;
            let texpr1 = expr expr1 in
            (* Unify each branch body with the one before it, in order. *)
            let _last : M.Types.type_expression =
              List.fold_left
                (fun prev_t (_, curr_expr) ->
                  let curr_t = expr curr_expr in
                  unify (prev_t, curr_t) ;
                  curr_t)
                texpr1 rest
            in
            texpr1 )
    | Type (_id, _cons, e) ->
        expr e

  (** [find_t env v] resolves variable [v] in the association list [env].
      Chains of variables are followed until a primitive, an arrow, or a
      variable with no entry in [env] is reached. An arrow is returned as
      found; variables inside it are not resolved. *)
  let find_t env v =
    let rec aux v =
      match List.assoc_opt v env with
      | None ->
          M.Types.Variable v
      | Some (M.Types.Variable y) ->
          aux y
      | Some ((M.Types.Primitive _ | M.Types.Arrow _) as t) ->
          t
    in
    aux v

  (** [expr e] infers the type of [e]: it runs the traversal, solves the
      recorded constraints, rewrites every entry of [infered] with its solved
      type, and returns the solved type of [e].

      @raise M.Error.Recursive_type
        when the solver raises [M.Types.Rectype].
      @raise M.Error.Wrong_application
        when the head of an application has a primitive type.
      @raise Failure when a [Match] has no cases. *)
  let expr e =
    let raw = expr e in
    let solved =
      try C.solve ()
      with M.Types.Rectype t -> raise (M.Error.Recursive_type t)
    in
    (* Substitute the solution into a type expression. *)
    let rec update = function
      | M.Types.Variable x ->
          find_t solved x
      | M.Types.Primitive _ as p ->
          p
      | M.Types.Arrow (t1, t2) ->
          let t1' = update t1 in
          let t2' = update t2 in
          M.Types.Arrow (t1', t2')
    in
    (* Rewrite every binding in place. Calling [Hashtbl.replace] from inside
       [Hashtbl.iter] mutates the table during iteration, which the standard
       library leaves unspecified, and it only touches the most recent
       binding of a key, so shadowed bindings added by [Bind] could overwrite
       the visible one. *)
    Hashtbl.filter_map_inplace (fun _ t -> Some (update t)) infered ;
    update raw
end