1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
open Lang
(* Type checker for [Lang] expressions, parameterised by [M].
   [M] supplies the type representation ([M.Types]), the exceptions raised
   for ill-typed input ([M.Error]) and the lookups giving the declared type
   of a variable ([M.get_type]) and of a custom constructor
   ([M.get_cons_type]).  The resulting module exports [literal], [const],
   [check_type] and [expr]: [expr e] returns () when [e] type-checks and
   otherwise raises [M.Error.Wrong_type] / [M.Error.Recursive_type] (or
   whatever the [M] lookups themselves raise). *)
module Make (M : sig
  module Types : sig
    (* Shares the primitive types of the outer [Types] module. *)
    type primitive_type = Types.primitive_type

    type type_expression =
      | Variable of string
      | Primitive of Types.primitive_type
      | Arrow of type_expression * type_expression

    (* Compatibility test consulted by [check_type] as
       [comp (found, expected)]. *)
    val comp : type_expression * type_expression -> bool
  end

  module Error : sig
    (* [Wrong_type (e, found, expected)], as built by [check_type]. *)
    exception Wrong_type of expr * Types.type_expression * Types.type_expression

    (* Raised by [expr] when the callee of an application does not have an
       [Arrow] type. *)
    exception Recursive_type of Types.type_expression
  end

  val get_type : string -> Types.type_expression
  val get_cons_type : string -> Types.primitive_type
end) =
struct
  (* Primitive type of a literal: [Unit] and [Bool] map to the built-in
     primitives; a [Custom] constructor is resolved through
     [M.get_cons_type]. *)
  let literal lit =
    match lit with
    | Unit -> Types.Unit
    | Bool _ -> Types.Bool
    | Custom id -> M.get_cons_type id

  (* Type of a constant: literals become [Primitive] types, variables are
     looked up with [M.get_type]. *)
  let const c =
    match c with
    | Literal lit -> M.Types.Primitive (literal lit)
    | Var name -> M.get_type name

  (* Raises [M.Error.Wrong_type (e, t, expected_t)] unless [M.Types.comp]
     accepts [t] against [expected_t]; [e] is the expression blamed. *)
  let check_type e t expected_t =
    if M.Types.comp (t, expected_t) then ()
    else raise (M.Error.Wrong_type (e, t, expected_t))

  (* Checks a whole expression and discards its type.  [infer] computes the
     type of each sub-expression; it is kept local so only the unit-returning
     entry point is exported. *)
  let expr =
    let rec infer e =
      match e with
      | Const c -> const c
      | Bind (x, bound, body) ->
          (* The bound expression must agree with the type [M.get_type]
             reports for [x]; the result is the body's type.  No local
             environment is threaded through here — variable types always
             come from [M.get_type] (presumably a pre-resolved table;
             verify against the caller). *)
          let found = infer bound in
          let declared = M.get_type x in
          check_type bound found declared ;
          infer body
      | Abstract (_, p, body) ->
          (* Kept as one constructor application: the order in which the
             two components are evaluated is left to the compiler, exactly
             as before. *)
          M.Types.Arrow (M.get_type p, infer body)
      | Apply (callee, argument) -> (
          let callee_t = infer callee in
          let argument_t = infer argument in
          match callee_t with
          | M.Types.Arrow (_, result_t) ->
              (* Compare the callee's type with an arrow built from the
                 argument's type and the callee's own result type; a
                 mismatch blames [callee]. *)
              check_type callee callee_t (M.Types.Arrow (argument_t, result_t)) ;
              result_t
          | _ -> raise (M.Error.Recursive_type callee_t) )
      | Match (_origin, scrutinee, cases) -> (
          let scrutinee_t = infer scrutinee in
          match cases with
          | [] ->
              (* NOTE(review): presumably an empty match cannot reach this
                 checker; confirm against the parser. *)
              failwith "internal error"
          | (first_con, first_rhs) :: rest ->
              (* The first case fixes the type of the whole match. *)
              let first_con_t = M.Types.Primitive (literal first_con) in
              let rhs_t = infer first_rhs in
              check_type scrutinee scrutinee_t first_con_t ;
              (* Every other case: its pattern must fit the scrutinee's type
                 (the pattern is blamed), and its right-hand side must fit
                 the first case's type. *)
              let check_case (con, rhs) =
                let con_t = M.Types.Primitive (literal con) in
                check_type (Const (Literal con)) con_t scrutinee_t ;
                let rhs_found = infer rhs in
                check_type rhs rhs_found rhs_t
              in
              List.iter check_case rest ;
              rhs_t )
      | Type (_id, _cons, body) ->
          (* Type declarations add nothing here; only the body is checked. *)
          infer body
    in
    fun e -> ignore (infer e)
end