まずは型推論で型チェック

1.型推論で型チェックをしっかり行う。
2.JVMコンパイルできるようにする。
3.アセンブラコンパイルできるようにする。

という順に、段階を踏んで進めていくことにします。
型推論ではまず、関数の型推論だけではなく、各段階の全ての式について型情報を残すようにします。

とりあえず、パラメータ多相までは出来ています。
今必要なのは暗黙的な型変換です。暗黙的な型変換は、アドホック多相の型推論に近いようなので、そちらを勉強してみます。

適当にScalaで書いた型推論のプログラムを以下に示します。
これは、Hindley-Milner 型推論(パッケージ名 hm)をベースに、推論した型を各式の型情報として書き戻すようにしたものです。

package hm

// Expression AST. Every node carries a type slot `t`. TypeSystem.infer binds
// that slot to the inferred type when the slot is a type variable, and
// TypeSystem.g rebuilds the tree with the slots resolved.
// NOTE(review): Exp is itself a case class that other case classes extend;
// newer Scala compilers reject case-to-case inheritance - confirm the target
// compiler version before upgrading.
sealed abstract case class Exp(val t:Type)
// Function literal: parameter names `vs` and a single body expression.
case class EFun(override val t:Type, vs: List[String], body: Exp) extends Exp(t)
// Reference to a name that is looked up in the typing environment.
case class EId(override val t:Type, name: String) extends Exp(t)
// Integer literal; its type slot is fixed to TInt().
case class EInt(i:Int) extends Exp(TInt())
// String literal; its type slot is fixed to TStr().
case class EStr(s:String) extends Exp(TStr())
// Binary addition node with operands `a` and `b`.
case class EAdd(override val t:Type, a:Exp,b:Exp) extends Exp(t)
// Application of `fn` to the argument list `args`.
case class Call(override val t:Type, fn: Exp, args: List[Exp]) extends Exp(t)
// Sequence of expressions evaluated in order.
case class EBlock(override val t:Type, vs:List[Exp]) extends Exp(t)
// Variable declaration `var v = defn`.
case class EVar(override val t:Type, v:String, defn:Exp) extends Exp(t)
// `return v` statement.
case class ERet(override val t:Type, v:Exp) extends Exp(t)

/** Pretty-printer for the expression AST. */
object Exp {
  /**
   * Renders an expression as source-like text.
   *
   * Type slots are ignored; only the syntactic structure is printed.
   * List-like parts (parameters, call arguments, block statements) are joined
   * with `mkString`, so a separator is emitted between every pair of elements,
   * including when an element renders as the empty string.
   *
   * @param Exp the expression to render
   * @return the textual form of the expression
   */
  def string(Exp: Exp): String = {
    Exp match {
    case EId(_, name) => name
    case EInt(i) => ""+i
    case EStr(s) => "\""+s+"\""
    // EAdd is part of the sealed hierarchy; without this case it would throw MatchError.
    case EAdd(_, a, b) => "("+string(a)+" + "+string(b)+")"
    case EFun(_, params, body) => "(fn ("+params.mkString(",")+") -> "+string(body)+")"
    case Call(_, fn, args) => string(fn)+"("+args.map(a => string(a)).mkString(",")+")"
    case EVar(_, name, defn) => "var "+name+" = "+string(defn)
    case EBlock(_, stmts) => "{" + stmts.map(s => string(s)).mkString(";") + "}"
    case ERet(_, value) => "return "+string(value)+";"
    }
  }
}


// Root of the type representation manipulated by TypeSystem.
sealed trait Type
// A type variable. `id` tells variables apart; `instance` is None while the
// variable is unbound and Some(type) once it has been bound (it is assigned by
// TypeSystem.unify, TypeSystem.prune, TypeSystem.p and TypeSystem.infer).
// NOTE(review): `instance` is a mutable case-class parameter, so it takes part
// in the generated equals/hashCode - be careful when using TVar as a hash key.
case class TVar(id: Int, var instance: Option[Type]) extends Type {
  // Display name, fetched from TypeSystem.nextUniqueName() on first access only,
  // so variables that are never printed do not consume a name.
  lazy val name:String = TypeSystem.nextUniqueName()
}
// A type constructor application: constructor `name` applied to `args`.
// TypeSystem.unify compares these by name and argument count.
case class TNew(name: String, args: Seq[Type]) extends Type
// Built-in nullary types, each a TNew with a fixed name and no arguments.
case class TInt() extends TNew("int", List())
case class TBool() extends TNew("bool", List())
case class TVoid() extends TNew("void", List())
case class TStr() extends TNew("str", List())
// Function type: a TNew named "Fun" whose args are the parameter types
// followed by the result type.
case class TFun(from:List[Type], to:Type) extends TNew("Fun", from:::List(to))

// Raised by TypeSystem when unification fails or a block redeclares a variable.
case class TypeError(msg: String) extends Exception(msg)
// Raised by TypeSystem.gettype when a name is missing from the environment.
case class ParseError(msg: String) extends Exception(msg)

object TypeSystem {
  /**
   * Allocates one fresh type variable per entry of `b`, in list order.
   * The strings themselves are not inspected; only their count matters.
   */
  def types(b:List[String]):List[Type] = b.map(_ => newVar())

  // Yields a brand-new unbound type variable on every access; infer uses it as
  // the type of a `return` expression. This is a term-level helper of this
  // object and is unrelated to the scala.Nothing type despite the name.
  def Nothing:Type = newVar()

  // Letter that the next display name will start with; cycles through 'a'..'z'.
  var _nextVarName = 'a'
  // Number of completed passes through the alphabet; appended as a numeric
  // suffix once it is greater than zero.
  private var _nextVarRound = 0

  /**
   * Hands out display names for type variables: "a".."z", then "a1".."z1",
   * "a2".. and so on.
   *
   * The first 26 names are exactly the single letters produced before; the
   * numeric suffix only appears afterwards, replacing the punctuation
   * characters ('{', '|', ...) that plain character increment would yield.
   *
   * @return a name not returned by any earlier call
   */
  def nextUniqueName():String = {
    val letter = _nextVarName
    val suffix = if (_nextVarRound == 0) "" else _nextVarRound.toString
    if (letter >= 'z') {
      // Alphabet exhausted: start over at 'a' with the next suffix.
      _nextVarName = 'a'
      _nextVarRound += 1
    } else {
      _nextVarName = (letter + 1).toChar
    }
    letter.toString + suffix
  }

  // Id that the next type variable created by newVar() will receive.
  var _nextVarId = 0

  /** Creates an unbound type variable carrying the next unused id. */
  def newVar():TVar = {
    val created = TVar(_nextVarId, None)
    _nextVarId += 1
    created
  }

  /**
   * Renders a type for display.
   *
   * Bound type variables are printed as the type they are bound to; unbound
   * ones print their display name. Constructor applications print as:
   *  - just the name when there are no arguments,
   *  - infix "(left name right)" when there are exactly two arguments,
   *  - prefix "name arg1 arg2 ..." otherwise.
   *
   * Arguments are always rendered through this method, so nested types never
   * leak their raw case-class `toString` form into the output.
   *
   * @param t the type to render
   * @return the display form of `t`
   */
  def string(t:Type):String = {
    t match {
      case TVar(_, Some(bound)) => string(bound)
      case unbound:TVar => unbound.name
      case TNew(name, args) =>
        args.map(a => string(a)) match {
          case Seq() => name
          case Seq(left, right) => "(" + left + " " + name + " " + right + ")"
          case rendered => rendered.mkString(name + " ", " ", "")
        }
    }
  }


  /**
   * Infers the type of `ast` under `env`, starting with no non-generic type
   * variables.
   */
  def infer(ast:Exp, env:Map[String, Type]):Type =
    infer(ast, env, Set.empty[TVar])
  /**
   * Types of the `return` expressions seen so far in the function body that is
   * currently being inferred. Each EFun starts with an empty set and restores
   * the enclosing function's set when it is done.
   */
  var returns = Set[Type]()

  /**
   * Infers the type of `ast` and records it in the node's type slot.
   *
   * @param ast    expression to type
   * @param env    types of the names in scope
   * @param nongen type variables that must not be generalised (the parameters
   *               of the enclosing functions)
   * @return the inferred type; if `ast.t` is a type variable it is also bound
   *         to this type
   * @throws ParseError when an identifier is missing from `env`
   * @throws TypeError  when unification fails or a block redeclares a variable
   */
  def infer(ast:Exp, env:Map[String, Type], nongen:Set[TVar]):Type = {
    val t: Type = ast match {
      case EId(_, name) => gettype(name, env, nongen)
      case EInt(_) => TInt()
      case EStr(_) => TStr()
      case EAdd(_, a, b) =>
        // Provisional rule: type both operands (so their slots get filled in)
        // and leave the result unconstrained, mirroring the fully polymorphic
        // signature the demo environment gives "add". Previously this node
        // fell through the match and raised MatchError.
        infer(a, env, nongen)
        infer(b, env, nongen)
        newVar()
      case Call(_, fn, args) => // function call
        val funtype = infer(fn, env, nongen)
        val argtypes = args.map(arg => infer(arg, env, nongen))
        val resulttype = newVar()
        unify(TFun(argtypes, resulttype), funtype)
        resulttype
      case EFun(_, args, body) => // function literal
        // Keep parameters as an ordered list: deriving the parameter types
        // from a Map loses declaration order for larger parameter lists and
        // collapses duplicate names, producing a wrong function type.
        val params: List[(String, TVar)] = args.map(name => name -> newVar())
        val paramVars: List[TVar] = params.map(_._2)
        val argtypes: List[Type] = paramVars
        val nongen2 = nongen ++ paramVars
        // Collect this body's returns in isolation, then hand the enclosing
        // function its own set back, even when inference of the body fails.
        val outerReturns = returns
        returns = Set[Type]()
        val (resulttype, bodyReturns) =
          try {
            val bodyType = infer(body, env ++ params, nongen2)
            (bodyType, returns)
          } finally {
            returns = outerReturns
          }
        bodyReturns.foreach(r => unify(r, resulttype))
        TFun(argtypes, resulttype)
      case EVar(_, _, defn) => infer(defn, env, nongen)
      case EBlock(_, xs) =>
        // Names declared so far in this block, used to reject redeclaration.
        var declared = Set[String]()
        def binfer(rest:List[Exp], scope:Map[String, Type]):Type = rest match {
          case Nil => TVoid()
          case last :: Nil => infer(last, scope, nongen)
          case head :: tail =>
            infer(head, scope, nongen)
            head match {
              case EVar(_, v, _) if declared.contains(v) =>
                throw new TypeError("Already defined "+v)
              case EVar(slot, v, _) =>
                declared = declared + v
                // The variable is entered with the declaration's own type
                // slot, which infer has just bound when it is a type variable.
                binfer(tail, scope + (v -> slot))
              case _ => binfer(tail, scope)
            }
        }
        binfer(xs, env)
      case ERet(_, v) =>
        // Infer first: the nested inference may itself add to `returns`, and
        // reading `returns` beforehand would discard those additions.
        val returned = infer(v, env, nongen)
        returns = returns + returned
        Nothing
    }
    // Record the result on the node so the annotated tree can be rebuilt later.
    ast.t match {
      case slot:TVar => slot.instance = Some(t)
      case _ =>
    }
    t
  }

  /**
   * Looks `name` up in `env` and returns a fresh instance of its type.
   *
   * @throws ParseError when `name` is not present in `env`
   */
  def gettype(name:String, env:Map[String, Type], nongen:Set[TVar]):Type =
    env.get(name) match {
      case Some(known) => fresh(known, nongen)
      case None => throw new ParseError("Undefined symbol " + name)
    }

  /**
   * Copies `t`, replacing every generic type variable with a new one.
   *
   * Variables that are not generic with respect to `nongen` are kept as they
   * are. TInt() and TBool() are returned unchanged; any other constructor
   * application is rebuilt as a TNew with copied arguments.
   */
  def fresh(t:Type, nongen:Set[TVar]):Type = {
    import scala.collection.mutable
    // One replacement per generic variable, so repeated occurrences stay linked.
    val renamed = new mutable.HashMap[TVar, TVar]
    def copy(tp:Type):Type = prune(tp) match {
      case v:TVar if isgeneric(v, nongen) => renamed.getOrElseUpdate(v, newVar())
      case v:TVar => v
      case ground @ (TInt() | TBool()) => ground
      case TNew(name, args) => TNew(name, args.map(arg => copy(arg)))
    }
    copy(t)
  }



  /**
   * Makes `t1` and `t2` equal by binding type variables.
   *
   * @throws TypeError on a recursive binding, or when two constructor
   *                   applications differ in name or argument count
   */
  def unify(t1:Type, t2:Type): Unit = {
    (prune(t1), prune(t2)) match {
      case (variable:TVar, other) =>
        // Identical variables need no binding.
        if (variable != other) {
          if (occursintype(variable, other)) {
            throw new TypeError("recursive unification")
          }
          variable.instance = Some(other)
        }
      case (applied:TNew, variable:TVar) =>
        // Swap so the variable case above does the binding.
        unify(variable, applied)
      case (left:TNew, right:TNew) =>
        val sameShape = left.name == right.name && left.args.length == right.args.length
        if (!sameShape) {
          throw new TypeError("Type mismatch: "+string(left)+"!="+string(right))
        }
        left.args.zip(right.args).foreach { case (l, r) => unify(l, r) }
    }
  }


  /**
   * Returns the type that `t` stands for.
   * A bound type variable yields the type it is bound to; chains of bound
   * variables are followed to the end and shortened along the way.
   */
  def prune(t:Type):Type = t match {
    case variable:TVar =>
      variable.instance match {
        case Some(bound) =>
          val resolved = prune(bound)
          // Point directly at the end of the chain for later lookups.
          variable.instance = Some(resolved)
          resolved
        case None => variable
      }
    case other => other
  }

  /**
   * True when `v` occurs in none of the non-generic variables.
   * Note: must be called with v 'pre-pruned'
   */
  def isgeneric(v: TVar, nongen: Set[TVar]):Boolean =
    nongen.forall(fixed => !occursintype(v, fixed))

  /**
   * True when `v` is `type2` itself (after pruning) or appears anywhere inside
   * its arguments.
   * Note: must be called with v 'pre-pruned'
   */
  def occursintype(v: TVar, type2: Type): Boolean = {
    val resolved = prune(type2)
    if (v == resolved) {
      true
    } else {
      resolved match {
        case TNew(_, args) => args.exists(arg => occursintype(v, arg))
        case _ => false
      }
    }
  }
  /** True when `name` matches the pattern for a run of one or more digits. */
  def isIntegerLiteral(name: String):Boolean =
    "^(\\d+)$".r.findFirstIn(name) match {
      case Some(_) => true
      case None => false
    }

  /** Applies `g` to every expression of the list, keeping the order. */
  def gs(es:List[Exp]):List[Exp] = es.map(e => g(e))
  /**
   * Resolves `t` all the way down: bound type variables are replaced by what
   * they are bound to, including inside the arguments of constructor
   * applications. Unbound variables and argument-less types are returned
   * as they are.
   */
  def p(t:Type):Type = t match {
    case variable:TVar =>
      variable.instance match {
        case Some(bound) =>
          val resolved = p(bound)
          // Remember the resolved form on the variable itself.
          variable.instance = Some(resolved)
          resolved
        case None => variable
      }
    case TNew(_, args) if args.size == 0 => t
    case TFun(params, result) => TFun(params.map(a => p(a)), p(result))
    case TNew(name, args) => TNew(name, args.map(a => p(a)))
  }

  /**
   * Rebuilds `ast` with every node's type slot resolved through `p`.
   * Integer and string literals are returned unchanged.
   */
  def g(ast:Exp):Exp = ast match {
    case EInt(_) => ast
    case EStr(_) => ast
    case EId(slot, name) => EId(p(slot), name)
    case EVar(slot, name, defn) => EVar(p(slot), name, g(defn))
    case ERet(slot, value) => ERet(p(slot), g(value))
    case EAdd(slot, left, right) => EAdd(p(slot), g(left), g(right))
    case EFun(slot, params, body) => EFun(p(slot), params, g(body))
    case Call(slot, fn, args) => Call(p(slot), g(fn), args.map(a => g(a)))
    case EBlock(slot, stmts) => EBlock(p(slot), stmts.map(s => g(s)))
  }
}

object hm {
  // Shorthand for a new unbound type variable; every use creates a distinct
  // one. The samples in main use it as the empty type slot of each AST node.
  def n:TVar = TypeSystem.newVar()

  // Demo driver: builds an environment of built-in signatures and prints, for a
  // series of sample expressions, either the inferred type or the error message.
  // The statement order matters for the printed output because type variable
  // ids and display names come from global counters in TypeSystem.
  def main(args: Array[String]){

    // Re-wrap standard output as an auto-flushing UTF-8 stream.
    Console.setOut(new java.io.PrintStream(Console.out, true, "utf-8"))

    val var1 = n
    val var2 = n
    // Pair type: a constructor named "," applied to the two component types.
    val pairtype = TNew(",", Array(var1, var2))

    val var3 = n

    // Built-in names available to every sample below.
    val env = Map(
      "pair" -> TFun(List(var1,var2), pairtype),
      "true" -> TBool(),
      "if" -> TFun(List(TBool(), var3, var3), var3),
      "zero" -> TFun(List(TInt()), TBool()),
      "pred" -> TFun(List(TInt()), TInt()),
      "mul"-> TFun(List(TInt(), TInt()), TInt()),
      "add"-> TFun(List(var1,var2),var3)
    )



    // Should fail:
    // fn x => pair(x(3), x(true))
    tryexp(env, EFun(n,List("x"),
          Call(n,EId(n,"pair"),List(
            Call(n,EId(n,"x"), List(EInt(3))),
            Call(n,EId(n,"x"), List(EId(n,"true")))
          ))))

    // pair(f(3), f(true))
    // "f" is not in env here, so this reports an undefined symbol.
    tryexp(env,
        Call(n,EId(n,"pair"), List(
		Call(n,EId(n,"f"), List(EInt(3))),
        	Call(n,EId(n,"f"), List(EId(n,"true"))))))

    // { var f = (fn x => x); pair(f(4), f(true)) }
    tryexp(env, EBlock(n,List(
      EVar(n,"f",
           EFun(n,List("x"), EId(n,"x"))),

      Call(n,EId(n,"pair"), List(
        Call(n,EId(n,"f"), List(EInt(4))),
        Call(n,EId(n,"f"), List(EId(n,"true")))))
    )))
    // fn f => f(f) (fail)
    tryexp(env, EFun(n,
          List("f"),
          Call(n,EId(n,"f"),List(EId(n,"f")))))

    // { var g = fn f => 5; g(g) }
    tryexp(env, EBlock(n,List(
        EVar(n,"g",
             EFun(n,List("f"), EInt(5))),
        Call(n,EId(n,"g"), List(EId(n,"g")))
    )))

    // example that demonstrates generic and non-generic Vars:
    // fn g => let f = fn x => g in pair (f 3, f true)
    /*
fn g => {
  var f = fn x => g;
  pair(f 3, f true)
}
*/
    tryexp(env, EFun(n,List("g"), EBlock(n,List(
           EVar(n,"f", EFun(n,List("x"), EId(n,"g"))),
           Call(n,EId(n,"pair"), List(
             Call(n,EId(n,"f"), List(EInt(3))),
             Call(n,EId(n,"f"), List(EId(n,"true")))
           ))
    ))))

    // Function composition
    // fn f => fn g => fn arg => g(f(arg))
    tryexp(env,
      EFun(n,List("f"),
           EFun(n,List("g"),
                EFun(n,List("arg"), Call(n,EId(n,"g"), List(Call(n,EId(n,"f"), List(EId(n,"arg"))))))))
    )
    // Printer check only (no inference): a block of two identifiers.
    println(Exp.string(EBlock(n,List(EId(n,"f"),EId(n,"d")))))
    // Empty block.
    tryexp(env, EBlock(n,List()))
    // { var a = 1; a }
    tryexp(env, EBlock(n,List[Exp](
      EVar(n,"a", EInt(1)),
      EId(n,"a")
    )))
    // { var a = 1; var a = true; a } - "a" is declared twice in the same block.
    tryexp(env, EBlock(n,List[Exp](
      EVar(n,"a", EInt(1)),
      EVar(n,"a", EId(n,"true")),
      EId(n,"a")
    )))
    // if(true, 4, 5)
    tryexp(env, Call(n,EId(n,"if"), List(EId(n,"true"),EInt(4),EInt(5))))
    // if(true, return 4, 5)
    tryexp(env, Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),EInt(5))))
    // { if(true, return 4, return 5); 4 }
    tryexp(env, EBlock(n,List[Exp](
      Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),ERet(n,EInt(5)))),
      EInt(4)
    )))

    // fn () => { if(true, return 4, return 5); true }
    tryexp(env, EFun(n,List(),EBlock(n,List[Exp](
      Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),ERet(n,EInt(5)))),
      EId(n,"true")
    ))))

    // mul is (int, int) -> int; add accepts any two argument types.
    tryexp(env, Call(n,EId(n,"mul"),List(EInt(1),EInt(2))))
    tryexp(env, Call(n,EId(n,"mul"),List(EStr("a"),EInt(2))))
    tryexp(env, Call(n,EId(n,"add"),List(EInt(1),EInt(2))))
    tryexp(env, Call(n,EId(n,"add"),List(EStr("a"),EInt(2))))
  }

  /**
   * Prints one report line for `ast`: its source form, " : ", and then either
   * the inferred type immediately followed by the annotated tree, or the
   * message of the ParseError/TypeError raised during inference.
   */
  def tryexp(env: Map[String, Type], ast: Exp): Unit = {
    print(Exp.string(ast) + " : ")
    val outcome: String =
      try {
        val inferred = TypeSystem.infer(ast, env)
        TypeSystem.string(inferred) + TypeSystem.g(ast)
      } catch {
        case ParseError(reason) => reason
        case TypeError(reason) => reason
      }
    print(outcome)
    println()
  }
}