Last time , we laid out the AST and Type for the language we are building. We also got a bird’s-eye view of our type inference algorithm: constraint generation, constraint solving, substitute our solved types. This time we’re implementing the constraint generation portion of our type inference algorithm.

Our passes will need to share some state between each other. We introduce a `TypeInference`

struct to hold this shared state and implement our passes as methods on that struct:

```
struct TypeInference {
unification_table: InPlaceUnificationTable<TypeVar>,
}
```

We’ll talk more about `unification_table`

when we talk about constraint solving. For now, it’s enough to think of it as a mapping from type variables to type, and we’ll use it in constraint generation to generate new type variables and keep track of them for later.

## Constraint Generation Link to heading

We generate our set of constraints from contextual information in our AST. To get this context we need to visit every node of our AST and collect constraints for that node. Traditionally this is done with a bottom-up tree traversal (e.g. HM’s Algorithm J). We visit all of a node’s children and then use the children’s context to better infer the type of our node. This approach is logically correct. We always infer the correct type and type error when we should. While correct, this approach doesn’t make the most efficient use of information available. For example, in an application node we know that the function child must have a function type. Since types are only inferred bottom-up, we have to infer an arbitrary type for our function node and add a constraint that the inferred type must be a function.

In recent years, a new approach to type checking called Bidirectional Type Checking has arisen to solve this inefficiency. With Bidirectional Type Checking we have two modes of type checking:

`infer`

- works the same as the HM type systems we just described, so types are inferred bottom-up.`check`

- works in the opposite direction, top-down. A type is passed into`check`

, and we*check*that our AST has the same type.

Our two modes will call each other mutually recursively to traverse the AST. Now when we want to type check an application node, we have a new option. We can construct a function type at the application node and check the function node against our constructed function type. Making better use of top-down contextual info like this allows us to generate fewer type variables and produce better error messages. Fewer type variables may not immediately appear as a benefit, but it makes the type checker faster. Our constraint solving is in part bound by the number of type variables we have to solve. It also makes debugging far easier. Each type variable acts as a point of indirection so fewer is better for debugging. So while Bidirectional Type Checking doesn’t allow us to infer “better” types, it does provide tangible benefits for a modest extension of purely bottom-up type inference.

Now that we know how we’ll be traversing our AST to collect constraints, let’s talk about what our constraints will actually look like. For our first Minimum Viable Product (MVP) of type inference, `Constraint`

will just be type equality:

```
enum Constraint {
TypeEqual(Type, Type)
}
```

We’ll talk more about what it means for two types to be equal during constraint solving. For now, it’s sufficient to produce a set of type equalities as a result of our constraint generation.

Constraint generation will be implemented on our `TypeInference`

struct with 3 methods:

```
impl TypeInference {
fn fresh_ty_var(&mut self)
-> TypeVar { ... }
fn infer(
&mut self,
env: &mut HashMap<Var, Type>,
ast: Ast<Var>
) -> (GenOut, Type) { ... }
fn check(
&mut self,
env: &mut HashMap<Var, Type>,
ast: Ast<Var>,
ty: Type
) -> GenOut { .. }
}
```

`fresh_ty_var`

is a helper method we’re going to brush past for now. (We’ll have a lot to cover in constraint solving!) It uses our `unification_table`

to produce a unique type variable every time we call it. Past that, we can see some parallels between our `infer`

and `check`

method that illustrate each mode. `infer`

takes an AST node and returns a type, whereas `check`

takes both an AST node and a type as parameters. This is because `infer`

is working bottom-up and `check`

is working top-down.

Let’s take a second to look at `env`

and `GenOut`

. Both `infer`

and `check`

take an `env`

parameter. This is used to track the type of AST variables in their current scope. `infer`

and `check`

both also return a `GenOut`

. This is a pair of our set of constraints and our typed AST:

```
struct GenOut {
// Set of constraints to be solved
constraints: Vec<Constraint>,
// Ast where all variables are annotated with their type
typed_ast: Ast<TypedVar>,
}
```

One final thing to note, we have no way to return an error from `infer`

or `check`

. We could of course panic, but for the sake of our future selves, we’ll return errors with `Result`

where relevant. It just so happens it’s not relevant for constraint generation. Our output is a set of constraints. It’s perfectly valid for us to return a set of constraints that contradict each other. We discover the contradiction and produce a type error when we try to solve our constraints. That means there aren’t error cases during constraint generation. Neat!

### infer Link to heading

With our setup out of the way, we can dive into our implementation of `infer`

and `check`

. We’ll cover `infer`

first. Because our AST begins untyped, we always call `infer`

first in our type inference, so it is a natural starting point. `infer`

is just a match on our input `ast`

:

```
fn infer(
&mut self,
env: &mut HashMap<Var, Type>,
ast: Ast<Var>
) -> (GenOut, Type) {
match ast {
Ast::Int(i) => todo!(),
Ast::Var(v) => todo!(),
Ast::Fun(arg, body) => todo!(),
Ast::App(fun, arg) => todo!(),
}
}
```

We’ll talk about each case individually, let’s start with an easy one to get our feet wet:

```
Ast::Int(i) => (
GenOut::new(
vec![],
Ast::Int(i)
),
Type::Int
),
```

When we see an integer literal, we know immediately that its type is `Int`

. We don’t need any constraints to be true for this to hold, so we return an empty `Vec`

. One step up in complexity over integers is our variable case:

```
Ast::Var(v) => {
let ty = &env[&v];
(
GenOut::new(
vec![],
// Return a `TypedVar` instead of `Var`
Ast::Var(TypedVar(v, ty.clone())
),
ty.clone(),
)
},
```

When we encounter a variable, we look up its type in our `env`

and return its type. Our env lookup might fail though. What happens if we ask for a variable we don’t have an entry for? That means we have an undefined variable, and we’ll `panic!`

. That’s fine for our purposes; we expect to have done some form of name resolution prior to type inference. If we encounter an undefined variable, we should’ve already exited with an error during name resolution. Past that, our `Var`

case looks very similar to our `Int`

case. We have no constraints to generate and immediately return the type we look up. Next we take a look at our `Fun`

case:

```
Ast::Fun(arg, body) => {
// Create a type variable for our unknown type variable
let arg_ty_var = self.fresh_ty_var();
// Add our agrument to our environment with it's type
let env = env.update(arg, Type::Var(arg_ty_var));
// Check the body of our function with our extended environment
let (body_out, body_ty) = self.infer(env, *body);
(
GenOut::new(
// body constraints are propagated
body_out.constraints,
Ast::fun(
// Our `Fun` holds a `TypedVar` now
TypedVar(arg, Type::Var(arg_ty_var)),
body_out.typed_ast,
),
),
Type::fun(Type::Var(arg_ty_var), body_ty),
)
}
```

`Fun`

is where we actually start doing some nontrivial inference. We create a fresh type variable and record it as the type of `arg`

in our `env`

. With our fresh type variable in scope, we infer a type for `body`

. We then use our inferred body type and generated argument type to construct a function type for our `Fun`

node. While `Fun`

itself doesn’t produce any constraints, it does pass on any constraints that `body`

generated. Now that we know how to type a function, let’s learn how to type a function application:

```
Ast::App(fun, arg) => {
let (arg_out, arg_ty) = self.infer(env.clone(), *arg);
let ret_ty = Type::Var(self.fresh_ty_var());
let fun_ty = Type::fun(arg_ty, ret_ty.clone());
// Because we inferred an argument type, we can
// construct a function type to check against.
let fun_out = self.check(env, *fun, fun_ty);
(
GenOut::new(
// Pass on constraints from both child nodes
arg_out
.constraints
.into_iter()
.chain(fun_out.constraints.into_iter())
.collect(),
Ast::app(fun_out.typed_ast, arg_out.typed_ast),
),
ret_ty,
)
}
```

`App`

is more nuanced than our previous cases. We infer the type of our `arg`

and use that to construct a function type with a fresh type variable as our return type. We use this function type to check our `fun`

node is a function type as well. Our final type for our `App`

node is our fresh return type, and we combine the constraints from `fun`

and `arg`

to produce our final constraint set.

You may wonder why we’ve chosen to infer the type for `arg`

instead of inferring a type for our `fun`

node. This would be reasonable and would produce equally valid results. We’ve opted not to for a few key reasons. If we infer a type for our `fun`

node, it is opaque. We know it has to be a function type, but all we have after inference is a `Type`

. To coerce it into a function type we have to emit a constraint against a freshly constructed function type:

```
let (fun_out, infer_fun_ty) = self.infer(env.clone(), *fun);
let arg_ty = self.fresh_ty_var();
let ret_ty = self.fresh_ty_var();
let fun_ty = Type::fun(arg_ty.clone(), ret_ty.clone());
let fun_constr = Constraint::TypeEqual(infer_fun_ty, fun_ty);
let arg_out = self.check(env, *arg, arg_ty);
// ...
```

We have to create an extra type variable and an extra constraint compared to inferring a type for `arg`

first. Not a huge deal, and in fact in more expressive type systems this tradeoff is worth inferring the function type first as it provides valuable metadata for checking the argument types. Our type system isn’t in that category, though, so we take fewer constraints and fewer type variables every time. Choices like this crop up a lot where it’s not clear when we should infer and when we should check our nodes. Bidirectional Typing
has an in-depth discussion of the tradeoffs and how to decide which approach to take.

### check Link to heading

That covers all of our inference cases, completing our bottom-up traversal. Next let’s talk about its sibling `check`

. Unlike `infer`

, `check`

does not cover every AST case explicitly. Because we are checking our AST against a known type, we only match on cases we know will check and rely on a catch-all bucket case to handle everything else. We’re still working case by case though, so at a high level our check looks very similar to `infer`

:

```
fn check(
&mut self,
ast: Ast<Var>,
ty: Type
) -> GenOut {
match (ast, ty) {
// ...
}
}
```

Notice we match on both our AST and type at once, so we can select just the cases we care about. Let’s look at our cases:

```
(Ast::Int(i), Type::Int) =>
GenOut::new(
vec![],
Ast::Int(i)
),
```

An integer literal trivially checks against the integer type. This case might appear superfluous; couldn’t we just let it be caught by the bucket case? Of course, we could, but this explicit case allows us to avoid type variables and avoid constraints. Our other explicit check case is for `Fun`

:

```
(Ast::Fun(arg, body), Type::Fun(arg_ty, ret_ty)) => {
let env = env.update(arg, *arg_ty);
self.check(env, *body, *ret_ty)
}
```

Our `Fun`

case is also straightforward. We decompose our `Type::Fun`

into it’s argument and return type. Record our `arg`

has `arg_ty`

in our `env`

, and then check that `body`

has `ret_ty`

in our updated `env`

. It almost mirrors our `infer`

’s `Fun`

case, but instead of bubbling a type up, we’re pushing a type down. Those are our only two explicit check cases. Everything else is handled by our bucket case:

```
(ast, expected_ty) => {
let (mut out, actual_ty) = self.infer(ast);
out.constraints
.push(Constraint::TypeEqual(expected_ty, actual_ty));
out
}
```

Finally, we have our bucket case. At first this might seem a little too easy. If we encounter an unknown pair, we just infer a type for our AST and add a constraint saying that type has to be equal to the type we’re checking against. If we think about this, it makes some sense though. In the unlikely case that neither of our types are variables (`(Int, Fun)`

or `(Fun, Int)`

), we will produce a type error when we try to solve our constraint. In the case that one of our types is a variable, we’ve now recorded the contextual info necessary about that variable by adding a constraint. We can rely on constraint solving to propagate that info to wherever it’s needed.

This is the only place where we emit a constraint explicitly. Everywhere else we just propagate constraints from our children’s recursive calls. The point where we switch from checking back to inference is the only point where we require a constraint to ensure our type line up. Our intuition for `infer`

and `check`

help guide us to that conclusion. This is in part the insight and the power of a bidirectional type system. It will only become more valuable as we extend our type system to handle more complicated types.

## Example Link to heading

It’s hard to see how our two functions fit together from just from their implementations. Let’s walk through an example to see `infer`

and `check`

in action. Consider a contrived AST:

```
Ast::app(
Ast::fun(
Var(0),
Ast::Var(Var(0))),
Ast::Int(3)
)
```

This is the identity function applied to an integer. A simple example, but it uses all of our AST nodes and will give us some insight into how `check`

lets us propagate more type information than `infer`

alone. Our example will use some notation to let us introspect our environment and use human friendly names:

```
x, y, z represent variables
α, β, γ represent type variables
env will use a literal
formatted as { <var0>: <type0>, <var1>: <type1>, ... }
with {} being an empty environment.
```

Using our new notation we can shorten our AST example:

```
Ast::app(
Ast:fun(x, Ast::Var(x)),
Ast::Int(3)
)
```

Okay, we start by calling `infer`

on our root `App`

node:

```
infer({}, Ast::App(...))
```

Our `App`

case starts by inferring a type for our `arg`

. Because our argument is `Ast::Int(3)`

its inferred type is `Int`

:

```
infer({}, Ast::Int(3)) = Int
```

We use this inferred argument, and a fresh return type, to construct a function type that we check against our `App`

’s `func`

:

```
check(
{},
Ast::Fun(x, Ast::Var(x)),
Fun(Int, Var(α))
)
```

A function and a function type is one of our specific check cases (it doesn’t fall into the bucket case). We destructure the function type to determine the type of our argument and body. This is where `check`

shines. If we just had `infer`

we would have to introduce a new type variable for `x`

and add a constraint that `x`

’s type variable must be `Int`

. Instead, we can immediately determine `x`

’s type must be `Int`

. With our env updated to include `x`

has type `Int`

, we check body against our function type’s return type:

```
check(
{ x: Int },
Var(x),
Var(α)
)
```

This is not a check case we know how to handle, it falls into the bucket case. The bucket case infers a type for our body. This looks up `x`

’s type in the environment and returns it:

```
infer({ x: Int }, Var(x)) = Int
```

We don’t show it in the example, but this will also return a new AST where `x`

is annotated with its type: `TypeVar(x, Int)`

. We’ll see how that gets used when we look at the final output of our example. A constraint is added that our checked type is equal to our inferred type:

```
Constraint::TypeEqual(
Var(α),
Int
)
```

Once we output that constraint we’re done calling `infer`

and `check`

. We propagate our constraints up the call stack and construct our typed AST as we go. At the end of returning from all our recursive calls we have our constraint set, with just one constraint:

```
vec![Constraint::TypeEqual(Var(α), Int)]
```

and our typed AST:

```
Ast::app(
Ast::fun(
TypeVar(x, Int),
Ast::Var(TypedVar(x, Int))
),
Ast::Int(3)
)
```

The final overall type of our AST, returned from our first `infer`

call, is `Var(α)`

, remember our function’s type is `Fun(Int, Var(α))`

. This illustrates why we need the final substitution step after constraint solving. Only once we’ve solved our constraints do we know `α = Int`

, and we can correctly determine our overall AST’s type is `Int`

.

With that we’ve finished generating our constraints. As output of constraint generation we produce three things: a set of constraints, a typed AST, and a Type for our whole AST. Our typed AST has a type associated to every variable (and from that we can recover the type of every node). However, a lot of these are still unknown type variables. We’ll save that AST for now and revisit it once we’ve solved our set of constraints and have a solution for all our type variables. Naturally then, next time we’ll implement constraint solving.