テストコード

テストをつけました。
今2日くらいで、以下のコードを12段階に分けて書く練習をしてます。

trait Exp
case class Const(a:Int) extends Exp
case class Var(a:String) extends Exp
case class Add(a:Exp,b:Exp) extends Exp
case class Let(a:String,b:Exp,c:Exp) extends Exp
case class Fun(a:String,c:Exp) extends Exp
case class App(a:Exp,b:Exp) extends Exp

trait Typ
case class TInt() extends Typ
case class TFun(a:Typ,b:Typ) extends Typ
case class TVar(var t:Option[Typ]) extends Typ

object tes {
  def main(argv:Array[String]) {
    println(infer(List(), Const(3))==TInt())
    val env = List("a"->TInt(),"b"->TInt())

    println(find(env, "a")==TInt())
    println(find(env, "b")==TInt())
    try {
      find(env, "c")
    } catch {
      case e:Exception => println(e.getMessage()=="error")
    }
    println(infer(env, Var("a"))==TInt())
    println(infer(env, Var("b"))==TInt())
    try {
      infer(env, Var("c"))
    } catch {
      case e:Exception => println(e.getMessage()=="error")
    }
    // 足し算
    println(infer(env, Add(Var("a"),Var("b")))==TInt())
    val exp = Let("c", Const(1), Var("c"))
    println(infer(env, exp)==TInt())

    unify(TFun(TInt(),TInt()), TFun(TInt(),TInt()))

    // 関数
    val fun = Fun("x", Add(Var("x"),Const(3)))
    println(infer(env, fun)==TFun(TVar(Some(TInt())),TInt()))
    unify(infer(env, fun),TFun(TInt(),TInt()))
    // 決まらない関数

    val fun2 = Fun("x", Var("x"))
    println(infer(env, fun2)==TFun(TVar(None),TVar(None)))

    val app = App(fun, Const(1))
    println(infer(env, app)==TVar(Some(TInt())))
    try {
      val occurFun = Fun("f", App(Var("f"), Var("f")))
      println(infer(env, occurFun))
    } catch {
      case e:Exception => println(e.getMessage()=="type error")
    }
  }

  def infer(tenv:List[(String, Typ)], e:Exp):Typ = {
    e match {
      case Const(_) => TInt()
      case Var(x) => find(tenv, x)
      case Add(x, y) =>
        unify(infer(tenv, x), TInt())
        unify(infer(tenv, y), TInt())
        TInt()
      case Let(x, e1, e2) =>
        val t = infer(tenv, e1)
        val newtenv = (x, t) :: tenv
        infer(newtenv, e2)
     case Fun(x, e0) =>
       val t1 = TVar(None)
       val newtenv = (x, t1)::tenv
       val t2 = infer(newtenv, e0)
       TFun(t1, t2)
     case App(e1, e2) =>
       val t1 = infer(tenv, e1)
       val t2 = infer(tenv, e2)
       val t = TVar(None)
       unify(t1, TFun(t2, t))
       t
    }
  }

  def unify(t1:Typ, t2:Typ) {
    (t1, t2) match {
      case (TInt(), TInt()) =>
      case (TFun(a1,b1), TFun(a2,b2))=>
        unify(a1,a2)
        unify(b1,b2)
      case (TVar(r1), TVar(r2)) if (r1 == r2) =>
      case (r@TVar(r1), _) if (!occur(r1, t2)) =>
        // t1はt2の中に現われない型変数
        r1 match {
        case None => // t1は未定
          r.t = Some(t2) // t1にt2を代入
        case Some(t1d) => // すでにt1'が代入されている
          unify(t1d, t2) // t1'とt2を等しくする
        }
      case (_, TVar(r2)) => // t2が型変数
        unify(t2, t1) // t1とt2を入れ替えてunify
      case (_, _) => throw new Exception("type error")
    }
  }



  def occur(r1:Option[Typ], t2:Typ):Boolean = {
    t2 match {
      case TInt() => false
      case TFun(t21, t22) => occur(r1, t21) || occur(r1, t22)
      case TVar(r2) => (r1 == r2) || (r2 match {
          case None => false
          case Some(t2d) => occur(r1, t2d)
        })
    }
  }

  def find(env:List[(String, Typ)], x:String):Typ = {
    env match {
      case List() => throw new Exception("error")
      case (`x`, t)::xs => t
      case _::xs => find(xs, x)
    }
  }
}