1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
open Lang
(** Functor building a type-inference pass over [Lang] expressions.

    The argument [M] supplies the representation of type expressions, a
    constraint store, the exceptions used to report failures, and the lookup
    from a custom constructor name to its primitive type. *)
module Make (M : sig
(* Representation of the types manipulated during inference. *)
module Types : sig
(* Alias of the [primitive_type] of the [Types] module visible outside this
signature. *)
type primitive_type = Types.primitive_type
(* A type is a named variable, a primitive type, or a function type
[Arrow (argument, result)]. *)
type type_expression =
| Variable of string
| Primitive of Types.primitive_type
| Arrow of type_expression * type_expression
(* Produces a type expression standing for a not-yet-known type.
NOTE(review): presumably a [Variable] with an unused name; confirm in the
implementation of [M]. *)
val mk_fresh : unit -> type_expression
(* Caught around [C.solve] in the body of [Make] and re-raised as
[Error.Recursive_type]. NOTE(review): presumably signals a recursive type
found while solving; confirm in the implementation of [M]. *)
exception Rectype of type_expression
(* Generative functor giving access to a constraint store. *)
module Constraint () : sig
(* Records a constraint between two type expressions. The body of [Make]
always passes a [Variable] as the first argument. *)
val add_constraint : type_expression -> type_expression -> unit
(* Solves the recorded constraints and returns an association list from
variable name to type expression. *)
val solve : unit -> (string * type_expression) list
end
end
(* Exceptions raised by the inference pass. *)
module Error : sig
(* Raised when the function position of an application has a primitive type;
carries the offending expression and its type. *)
exception Wrong_application of expr * Types.type_expression
(* Raised in place of [Types.Rectype] when solving fails. *)
exception Recursive_type of Types.type_expression
end
(* Maps the name of a custom constructor to its primitive type. *)
val get_cons_type : string -> Types.primitive_type
end) =
struct
(* Constraint store used by every function of this functor instance. *)
module C = M.Types.Constraint ()
(** [unify (lhs, rhs)] records the constraints needed for [lhs] and [rhs] to
    denote the same type.

    - If either side is a [Variable], one constraint is added through
      [C.add_constraint], always with the variable as first argument (for two
      variables, [lhs] comes first).
    - Two [Arrow] types are unified component-wise: arguments first, then
      results.
    - Every remaining pair is ignored.

    NOTE(review): pairs mixing [Primitive] with [Primitive] or with [Arrow]
    are accepted silently, even when the two sides differ; no error is
    reported for them here. *)
let rec unify (lhs, rhs) =
  let open M.Types in
  match (lhs, rhs) with
  | Variable _, (Variable _ | Primitive _ | Arrow _) ->
    C.add_constraint lhs rhs
  | (Primitive _ | Arrow _), Variable _ ->
    (* Keep the variable on the left of the recorded constraint. *)
    C.add_constraint rhs lhs
  | Arrow (dom_l, cod_l), Arrow (dom_r, cod_r) ->
    unify (dom_l, dom_r) ;
    unify (cod_l, cod_r)
  | Primitive _, (Primitive _ | Arrow _) | Arrow _, Primitive _ ->
    ()
(* Table associating an identifier with the type expression recorded for it.
   Entries are added by [mk_fresh] and by the [Bind] case of the recursive
   [expr], looked up by [const] and by the [Abstract] case, and rewritten
   with solved types by the final [expr]. Insertions use [Hashtbl.add], so a
   key may carry several stacked bindings. 512 is the initial size passed to
   [Hashtbl.create]. *)
let infered = Hashtbl.create 512
(** [mk_fresh x] obtains a new type expression from [M.Types.mk_fresh],
    records it in [infered] under key [x] (stacking on top of any existing
    binding for [x]), and returns it. *)
let mk_fresh x =
  let ty = M.Types.mk_fresh () in
  let () = Hashtbl.add infered x ty in
  ty
(** [literal lit] gives the primitive type of the literal [lit]: [Types.Unit]
    for [Unit], [Types.Bool] for any [Bool], and the result of
    [M.get_cons_type] for a [Custom] constructor name. *)
let literal lit =
  match lit with
  | Unit ->
    Types.Unit
  | Bool _ ->
    Types.Bool
  | Custom cons_name ->
    M.get_cons_type cons_name
(** [const c] gives the type expression of the constant [c].

    A [Literal] is wrapped as a [Primitive] of its [literal] type. A [Var]
    yields the type already recorded for it in [infered]; when none is
    recorded, a new one is created and recorded through [mk_fresh]. *)
let const c =
  match c with
  | Literal lit ->
    M.Types.Primitive (literal lit)
  | Var name -> (
    match Hashtbl.find_opt infered name with
    | Some known ->
      known
    | None ->
      mk_fresh name )
(** [expr e] walks the expression [e], records unification constraints
    through [unify], and returns the (not yet solved) type expression of [e].

    - [Const c]: the type given by [const].
    - [Bind (id, e1, e2)]: the type of [e1] is recorded for [id] in
      [infered], then [e2] is typed.
    - [Abstract (_, id, body)]: an [Arrow] from the type recorded for [id] to
      the type of [body].
    - [Apply (e1, e2)]: the result type of [e1] once its argument type has
      been unified with the type of [e2].
    - [Match (_, scrutinee, cases)]: the type of the first branch; the
      scrutinee is unified with the first pattern and each branch with the
      previous one.
    - [Type (_, _, e)]: the type of [e].

    @raise M.Error.Wrong_application when the function position of an
    application has a primitive type.
    @raise Failure when a [Match] has no case.

    NOTE(review): bindings added to [infered] are never removed, so a name
    stays recorded after the expression that introduced it has been typed,
    and an [Abstract] parameter reuses any binding already recorded under the
    same name. Confirm that this is intended (e.g. names made unique by an
    earlier pass). *)
let rec expr = function
  | Const c ->
    const c
  | Bind (id, e1, e2) ->
    let t = expr e1 in
    Hashtbl.add infered id t ;
    expr e2
  | Abstract (_, id, e) ->
    let t_body = expr e in
    (* A binding for [id] only exists if the body mentioned it (or if one was
       recorded earlier). A parameter that is never used gets a new type
       expression instead of letting [Not_found] escape. *)
    let t_param =
      match Hashtbl.find_opt infered id with
      | Some t ->
        t
      | None ->
        mk_fresh id
    in
    M.Types.Arrow (t_param, t_body)
  | Apply (e1, e2) -> (
    let t1 = expr e1 in
    let t2 = expr e2 in
    match t1 with
    | M.Types.Arrow (t_in, t_out) ->
      unify (t2, t_in) ;
      t_out
    | M.Types.Variable _ ->
      (* The function type is still unknown: constrain it to be an arrow from
         the argument type to a new result type. *)
      let t_out = M.Types.mk_fresh () in
      let t_new = M.Types.Arrow (t2, t_out) in
      unify (t_new, t1) ;
      t_out
    | M.Types.Primitive _ ->
      raise @@ M.Error.Wrong_application (e1, t1) )
  | Match (_origin, match_expr, cases) -> (
    let t_match_expr = expr match_expr in
    match cases with
    | [] ->
      failwith "internal error"
    | (case1, expr1) :: s ->
      (* Only the first pattern is unified with the scrutinee. *)
      let tcase1 = const (Literal case1) in
      unify (t_match_expr, tcase1) ;
      let texpr1 = expr expr1 in
      (* Chain the branch types: each branch is unified with the previous
         one. *)
      ignore
      @@ List.fold_left
           (fun prev_t curr_expr ->
             let curr_t = expr curr_expr in
             unify (prev_t, curr_t) ;
             curr_t)
           texpr1 (List.map snd s) ;
      texpr1 )
  | Type (_id, _cons, e) ->
    expr e
(** [find_t env v] resolves the type variable named [v] against the
    association list [env] (variable name, type expression).

    A name without entry in [env] is returned as [Variable] of that name. An
    entry that is itself a [Variable] is followed recursively. A [Primitive]
    or [Arrow] entry is returned as found; the components of an [Arrow] are
    not resolved further here.

    NOTE(review): a cycle of [Variable] entries in [env] would make this
    loop; presumably the solver never produces one. Confirm. *)
let find_t env v =
  let open M.Types in
  let rec resolve name =
    match List.assoc_opt name env with
    | None ->
      Variable name
    | Some (Variable next) ->
      resolve next
    | Some ((Primitive _ | Arrow _) as ty) ->
      ty
  in
  resolve v
(** [expr e] infers the type of [e].

    The recursive [expr] defined above collects the constraints, [C.solve]
    solves them, and the solution is then applied both to every type recorded
    in [infered] and to the type of [e], which is returned.

    @raise M.Error.Recursive_type when [C.solve] raises [M.Types.Rectype].
    Exceptions raised while collecting constraints propagate unchanged.

    NOTE(review): a variable whose solution is an [Arrow] is replaced by that
    arrow as returned by [find_t], without resolving its components again;
    this assumes [C.solve] returns fully substituted types. Confirm. *)
let expr e =
  let raw = expr e in
  let solved =
    try C.solve ()
    with M.Types.Rectype t -> raise @@ M.Error.Recursive_type t
  in
  (* Applies the solution to a type expression, arguments before results. *)
  let rec update =
    let open M.Types in
    function
    | Variable x ->
      find_t solved x
    | Primitive _ as p ->
      p
    | Arrow (t1, t2) ->
      let t_arg = update t1 in
      let t_res = update t2 in
      Arrow (t_arg, t_res)
  in
  (* Rewrite every binding in place. Calling [Hashtbl.replace] from inside
     [Hashtbl.iter] modifies the table during iteration, which the standard
     library leaves unspecified, and it would overwrite the most recent
     binding of a key with the type of each older binding of that key. *)
  Hashtbl.filter_map_inplace (fun _ var_type -> Some (update var_type)) infered ;
  update raw
end