1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
open Lang

module Make (M : sig
  module Types : sig
    type primitive_type = Types.primitive_type

    type type_expression =
      | Variable of string
      | Primitive of Types.primitive_type
      | Arrow of type_expression * type_expression

    (** Type compatibility test. This functor reports a type error whenever
        [comp (t, expected)] is [false]. *)
    val comp : type_expression * type_expression -> bool
  end

  module Error : sig
    (** Payload: the offending expression, its computed type, and the type
        it was expected to have. *)
    exception Wrong_type of expr * Types.type_expression * Types.type_expression

    (** Raised by this functor's [expr] with the type of an applied
        expression when that type is not an [Arrow]. *)
    exception Recursive_type of Types.type_expression
  end

  (** Type associated with a variable / parameter name (consulted for [Var],
      [Bind] and [Abstract]). *)
  val get_type : string -> Types.type_expression

  (** Primitive type of the custom literal with the given name. *)
  val get_cons_type : string -> Types.primitive_type
end) =
struct
  (** [literal lit] is the primitive type of the literal [lit]; custom
      literals are resolved through [M.get_cons_type]. *)
  let literal lit =
    match lit with
    | Unit -> Types.Unit
    | Bool _ -> Types.Bool
    | Custom name -> M.get_cons_type name

  (** [const c] is the type of the constant [c]: a literal becomes a
      [Primitive] type, a variable gets whatever [M.get_type] reports. *)
  let const c =
    match c with
    | Literal lit -> M.Types.Primitive (literal lit)
    | Var name -> M.get_type name

  (** [check_type e t expected_t] succeeds silently when
      [M.Types.comp (t, expected_t)] holds.
      @raise M.Error.Wrong_type [(e, t, expected_t)] otherwise. *)
  let check_type e t expected_t =
    if M.Types.comp (t, expected_t) then ()
    else raise (M.Error.Wrong_type (e, t, expected_t))

  (** [expr node] computes the type of [node], checking sub-terms on the way:
      - [Const]: see [const].
      - [Bind (x, bound, body)]: [bound] is checked against [M.get_type x];
        the result is the type of [body].
      - [Abstract (_, p, body)]: [Arrow (M.get_type p, type of body)].
      - [Apply (fn, arg)]: [fn] must have an [Arrow] type; that whole type is
        checked against [Arrow (type of arg, output)] and the output type is
        returned.
      - [Match]: every case literal must agree with the scrutinee's type and
        every arm with the first arm's type, which is returned.
      - [Type (_, _, body)]: transparent, types [body].
      @raise M.Error.Wrong_type when a [check_type] fails.
      @raise M.Error.Recursive_type when a non-[Arrow] type is applied.
      @raise Failure ["internal error"] for a [Match] with no cases. *)
  let rec expr node =
    match node with
    | Const c -> const c
    | Bind (x, bound, body) ->
        let t_bound = expr bound in
        check_type bound t_bound (M.get_type x) ;
        expr body
    | Abstract (_, param, body) ->
        (* Kept as a single constructor application so the (compiler-chosen)
           evaluation order of [M.get_type] vs. [expr] is unchanged. *)
        M.Types.Arrow (M.get_type param, expr body)
    | Apply (fn, arg) -> (
        let t_fn = expr fn in
        let t_arg = expr arg in
        match t_fn with
        | M.Types.Arrow (_, t_out) ->
            check_type fn t_fn (M.Types.Arrow (t_arg, t_out)) ;
            t_out
        | M.Types.Variable _ | M.Types.Primitive _ ->
            (* NOTE(review): the exception name suggests an occurs-check
               failure, but here it signals applying a non-arrow type —
               confirm this is the intended error. *)
            raise (M.Error.Recursive_type t_fn) )
    | Match (_origin, scrutinee, cases) -> (
        (* The scrutinee is typed before the case list is inspected, so its
           errors take precedence over the empty-match failure. *)
        let t_scrutinee = expr scrutinee in
        match cases with
        | [] -> failwith "internal error"
        | (first_con, first_arm) :: rest ->
            let t_first_con = M.Types.Primitive (literal first_con) in
            let t_arm = expr first_arm in
            check_type scrutinee t_scrutinee t_first_con ;
            (* Remaining cases: literal against the scrutinee first, then
               the arm body against the first arm's type. *)
            let rec check_remaining = function
              | [] -> ()
              | (con, arm) :: more ->
                  let t_con = M.Types.Primitive (literal con) in
                  check_type (Const (Literal con)) t_con t_scrutinee ;
                  let t_this_arm = expr arm in
                  check_type arm t_this_arm t_arm ;
                  check_remaining more
            in
            check_remaining rest ;
            t_arm )
    | Type (_id, _cons, body) -> expr body

  (** Public entry point: type-checks [e] purely for its exceptions and
      discards the computed type. Shadows the recursive [expr] above. *)
  let expr e =
    let (_ : M.Types.type_expression) = expr e in
    ()
end