




package hm

/**
 * Expression AST. Every node carries a type annotation `t`; when that
 * annotation is an unbound [[TVar]], `TypeSystem.infer` binds it to the
 * node's inferred type.
 *
 * Declared as a plain abstract class rather than a case class: the node
 * types below are case classes, and case-to-case inheritance is rejected by
 * Scala 2.10+.
 */
sealed abstract class Exp(val t: Type)

/** Function literal with parameter names `vs` and body `body`. */
case class EFun(override val t: Type, vs: List[String], body: Exp) extends Exp(t)

/** Reference to a name that is looked up in the typing environment. */
case class EId(override val t: Type, name: String) extends Exp(t)

/** Integer literal; its annotation is always `int`. */
case class EInt(i: Int) extends Exp(TInt())

/** String literal; its annotation is always `str`. */
case class EStr(s: String) extends Exp(TStr())

/** Binary node with operands `a` and `b` (presumably addition, per the name). */
case class EAdd(override val t: Type, a: Exp, b: Exp) extends Exp(t)

/** Call of `fn` with the argument list `args`. */
case class Call(override val t: Type, fn: Exp, args: List[Exp]) extends Exp(t)

/** Statement sequence; typed as its last statement, or void when empty. */
case class EBlock(override val t: Type, vs: List[Exp]) extends Exp(t)

/** `var v = defn`; the binding is visible to later statements of the enclosing block. */
case class EVar(override val t: Type, v: String, defn: Exp) extends Exp(t)

/** `return v` from the enclosing function. */
case class ERet(override val t: Type, v: Exp) extends Exp(t)

object Exp {
  /**
   * Renders an expression as compact, source-like text, e.g.
   * `(fn (x) -> f(x))`. Type annotations are not printed.
   *
   * Every node kind is handled; previously an [[EAdd]] node fell through the
   * match and raised a `MatchError`.
   *
   * @param Exp the expression to print (the parameter keeps its original,
   *            type-shadowing name so named-argument callers still compile)
   * @return the printed form of the expression
   */
  def string(Exp: Exp): String = Exp match {
    case EId(_, name) => name
    case EInt(i) => i.toString
    case EStr(s) => "\"" + s + "\""
    case EFun(_, params, body) => "(fn (" + params.mkString(",") + ") -> " + string(body) + ")"
    case EAdd(_, a, b) => "(" + string(a) + " + " + string(b) + ")"
    case Call(_, fn, args) => string(fn) + args.map(string).mkString("(", ",", ")")
    case EVar(_, v, defn) => "var " + v + " = " + string(defn)
    case EBlock(_, stmts) => stmts.map(string).mkString("{", ";", "}")
    case ERet(_, v) => "return " + string(v) + ";"
  }
}

/** A type: either a type variable ([[TVar]]) or a constructor application ([[TNew]]). */
sealed trait Type

/**
 * Type variable. `instance` is `None` while unbound and is set during
 * unification; `id` is unique for variables allocated by `TypeSystem.newVar`.
 *
 * Equality and hashing use `id` only. `instance` is mutated during
 * inference while TVars sit in sets and hash maps, so a structural hash
 * would silently change under them.
 */
case class TVar(id: Int, var instance: Option[Type]) extends Type {
  /** Display name, allocated lazily the first time it is needed. */
  lazy val name: String = TypeSystem.nextUniqueName()

  override def equals(other: Any): Boolean = other match {
    case that: TVar => id == that.id
    case _ => false
  }

  override def hashCode: Int = id.##
}

/**
 * Type constructor `name` applied to `args`, e.g. `int` (no arguments) or
 * `Fun` (parameter types followed by the result type).
 *
 * A regular class with a hand-written companion instead of a case class:
 * the built-in types below are case classes extending it, and case-to-case
 * inheritance is rejected by Scala 2.10+. The companion keeps
 * `TNew(name, args)` construction and patterns working; equality is
 * structural over `name` and `args`.
 */
class TNew(val name: String, val args: Seq[Type]) extends Type {
  override def equals(other: Any): Boolean = other match {
    case that: TNew => name == that.name && args == that.args
    case _ => false
  }

  override def hashCode: Int = (name, args).##

  override def toString: String =
    if (args.isEmpty) name else args.mkString(name + "(", ", ", ")")
}

object TNew {
  def apply(name: String, args: Seq[Type]): TNew = new TNew(name, args)

  def unapply(t: TNew): Option[(String, Seq[Type])] = Some((t.name, t.args))
}

case class TInt() extends TNew("int", List())
case class TBool() extends TNew("bool", List())
case class TVoid() extends TNew("void", List())
case class TStr() extends TNew("str", List())

/** Function type; its constructor arguments are `from` followed by `to`. */
case class TFun(from: List[Type], to: Type) extends TNew("Fun", from ::: List(to))

/**
 * Signals a typing failure: two types that cannot be unified, a recursive
 * unification, or a `var` declared twice in one block.
 */
case class TypeError(msg: String)
  extends Exception(msg)

/**
 * Signals a reference to a symbol that is not bound in the typing
 * environment. Despite the name, it is raised during type inference.
 */
case class ParseError(msg: String)
  extends Exception(msg)

object TypeSystem {
  /** Allocates one fresh, unbound type variable per element of `b`, in order. */
  def types(b: List[String]): List[Type] = b.map(_ => newVar())

  /**
   * Type of an expression that yields no value to its context (used for
   * `return`): a fresh, unconstrained variable, so it unifies with anything.
   */
  def Nothing: Type = newVar()

  /** Next display name handed out by [[nextUniqueName]]. */
  var _nextVarName = 'a'

  /**
   * Returns the next type-variable display name: "a", "b", "c", ...
   * NOTE(review): past 'z' this continues through the following characters
   * ('{', '|', ...) rather than wrapping.
   */
  def nextUniqueName(): String = {
    val result = _nextVarName
    _nextVarName = (_nextVarName.toInt + 1).toChar
    result.toString
  }

  /** Next id handed out by [[newVar]]. */
  var _nextVarId = 0

  /** Allocates a new unbound type variable with a unique id. */
  def newVar(): TVar = {
    val result = _nextVarId
    _nextVarId += 1
    TVar(result, None)
  }

  /**
   * Pretty-prints a type, following bound variables to their instances.
   * Nullary constructors print as their name, binary ones infix
   * (e.g. `(int Fun bool)`), any other arity in prefix form
   * (e.g. `Fun int int int`).
   */
  def string(t: Type): String = t match {
    case TVar(_, Some(i)) => string(i)
    case v: TVar => v.name
    case TNew(name, args) =>
      // Arguments are rendered with `string`, not `toString`, so nested
      // variables and constructors use the same notation.
      if (args.isEmpty) name
      else if (args.length == 2) "(" + string(args(0)) + " " + name + " " + string(args(1)) + ")"
      else args.map(a => string(a)).mkString(name + " ", " ", "")
  }

  /** Infers the type of `ast` in `env` with no non-generic variables. */
  def infer(ast: Exp, env: Map[String, Type]): Type = infer(ast, env, Set())

  /**
   * Types of the `return` expressions seen so far in the body of the
   * innermost function being inferred; they are unified with that
   * function's result type.
   */
  var returns = Set[Type]()

  /**
   * Infers the type of `ast`.
   *
   * As a side effect, if the node's annotation `ast.t` is a [[TVar]], it is
   * bound to the inferred type.
   *
   * @param env    types of the names in scope
   * @param nongen variables that must not be generalized (function parameters
   *               of enclosing functions)
   * @return the inferred type
   * @throws TypeError  on a type mismatch, a recursive unification, a `var`
   *                    declared twice in one block, or an [[EAdd]] node
   * @throws ParseError when an identifier is not bound in `env`
   */
  def infer(ast: Exp, env: Map[String, Type], nongen: Set[TVar]): Type = {
    val t: Type = ast match {
      case EId(_, name) => gettype(name, env, nongen)
      case EInt(_) => TInt()
      case EStr(_) => TStr()
      case Call(_, fn, args) =>
        // Function call: the callee must unify with (argument types) -> result.
        val funtype = infer(fn, env, nongen)
        val argtypes = args.map(arg => infer(arg, env, nongen))
        val resulttype = newVar()
        unify(TFun(argtypes, resulttype), funtype)
        resulttype
      case EFun(_, params, body) =>
        // Function literal. Parameter variables are kept in a List so the TFun
        // parameter order always matches `params`; an immutable Map's iteration
        // order is not guaranteed to follow insertion order.
        val paramtypes: List[TVar] = params.map(_ => newVar())
        val bodyenv = env ++ params.zip(paramtypes)
        // Save the enclosing function's pending returns so a nested function
        // neither discards them nor leaks its own returns into them.
        val outerreturns = returns
        returns = Set[Type]()
        try {
          val resulttype = infer(body, bodyenv, nongen ++ paramtypes)
          returns.foreach(r => unify(r, resulttype))
          TFun(paramtypes, resulttype)
        } finally {
          returns = outerreturns
        }
      case EVar(_, _, defn) => infer(defn, env, nongen)
      case EBlock(_, stmts) =>
        // Statements are inferred left to right. A `var` binds its name for the
        // statements after it; its type is not added to `nongen`, so `fresh`
        // generalizes it (let-polymorphism). The block is typed as its last
        // statement, or void when empty.
        def binfer(xs: List[Exp], benv: Map[String, Type], defined: Set[String]): Type = xs match {
          case Nil => TVoid()
          case x :: rest =>
            val xt = infer(x, benv, nongen)
            val (nextenv, nextdefined) = x match {
              // Checked for every statement, including the last one.
              case EVar(_, v, _) if defined.contains(v) => throw new TypeError("Already defined " + v)
              case EVar(_, v, _) => (benv + (v -> xt), defined + v)
              case _ => (benv, defined)
            }
            if (rest.isEmpty) xt else binfer(rest, nextenv, nextdefined)
        }
        binfer(stmts, env, Set())
      case ERet(_, v) =>
        returns = returns + infer(v, env, nongen)
        Nothing
      case EAdd(_, _, _) =>
        // NOTE(review): there is no typing rule for EAdd. Fail with a TypeError
        // (which callers catch) rather than an uncaught MatchError.
        throw new TypeError("Unsupported expression: EAdd")
    }
    // Record the result in the node's annotation variable; `g`/`p` later read
    // it back to produce a fully typed tree.
    ast.t match {
      case a: TVar => a.instance = Some(t)
      case _ =>
    }
    t
  }

  /**
   * Looks up `name` and returns a fresh instance of its type.
   *
   * @throws ParseError when `name` is not bound in `env`
   */
  def gettype(name: String, env: Map[String, Type], nongen: Set[TVar]): Type =
    env.get(name) match {
      case Some(t) => fresh(t, nongen)
      case None => throw new ParseError("Undefined symbol " + name)
    }

  /**
   * Copies `t`, replacing every generic variable (one not occurring in
   * `nongen`) with a new variable. Repeated occurrences of the same variable
   * map to the same new variable. Non-generic variables are shared, not
   * copied.
   */
  def fresh(t: Type, nongen: Set[TVar]): Type = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[TVar, TVar]
    def freshrec(tp: Type): Type = prune(tp) match {
      case v: TVar =>
        if (isgeneric(v, nongen)) mappings.getOrElseUpdate(v, newVar()) else v
      // Keep function types as TFun rather than degrading them to a plain TNew.
      case TFun(from, to) => TFun(from.map(freshrec(_)), freshrec(to))
      // Nullary constructors (int, bool, void, str, ...) contain no variables.
      case nullary @ TNew(_, args) if args.isEmpty => nullary
      case TNew(name, args) => TNew(name, args.map(freshrec(_)))
    }
    freshrec(t)
  }

  /**
   * Makes `t1` and `t2` equal by binding type variables in place
   * (through `TVar.instance`).
   *
   * @throws TypeError "recursive unification" if a variable would be bound to
   *                   a type containing itself, or "Type mismatch" if two
   *                   constructors differ in name or arity
   */
  def unify(t1: Type, t2: Type): Unit = {
    val type1 = prune(t1)
    val type2 = prune(t2)
    (type1, type2) match {
      case (a: TVar, b) if a == b =>
      case (a: TVar, b) =>
        if (occursintype(a, b)) throw new TypeError("recursive unification")
        a.instance = Some(b)
      case (a: TNew, b: TVar) => unify(b, a)
      case (a: TNew, b: TNew) =>
        if (a.name != b.name || a.args.length != b.args.length)
          throw new TypeError("Type mismatch: " + string(a) + "!=" + string(b))
        a.args.zip(b.args).foreach { case (x, y) => unify(x, y) }
    }
  }

  /**
   * Returns the type that `t` currently stands for. A bound variable is
   * replaced by its instance, following chains of bound variables. Each
   * variable on the chain is rebound directly to the final result, which
   * collapses the chain. Only the top level is resolved; constructor
   * arguments are left as-is (see [[p]] for a deep version).
   */
  def prune(t: Type): Type = t match {
    case v @ TVar(_, Some(i)) =>
      val inst = prune(i)
      v.instance = Some(inst)
      inst
    case _ => t
  }

  /**
   * True when `v` occurs in none of the types in `nongen`.
   * Note: must be called with `v` already pruned.
   */
  def isgeneric(v: TVar, nongen: Set[TVar]): Boolean =
    !nongen.exists(t => occursintype(v, t))

  /**
   * True when `v` occurs anywhere inside `type2`.
   * Note: must be called with `v` already pruned.
   */
  def occursintype(v: TVar, type2: Type): Boolean = prune(type2) match {
    case `v` => true
    case TNew(_, args) => args.exists(t => occursintype(v, t))
    case _ => false
  }

  // Compiled once rather than on every call.
  private val checkDigits = "^(\\d+)$".r

  /** True when `name` is a run of decimal digits. */
  def isIntegerLiteral(name: String): Boolean =
    checkDigits.findFirstIn(name).isDefined

  /** Applies [[g]] to each expression, preserving order. */
  def gs(es: List[Exp]): List[Exp] = es.map(g)

  /**
   * Deeply resolves `t`. Bound variables are replaced by their instances, and
   * constructor arguments are resolved recursively (a TFun stays a TFun).
   * Chains of bound variables are collapsed as in [[prune]]. Unbound
   * variables are returned unchanged.
   */
  def p(t: Type): Type = t match {
    case v @ TVar(_, Some(i)) =>
      val inst = p(i)
      v.instance = Some(inst)
      inst
    case TNew(_, args) if args.isEmpty => t
    case TFun(from, to) => TFun(from.map(p), p(to))
    case TNew(name, args) => TNew(name, args.map(p))
    case _ => t
  }

  /**
   * Rebuilds `ast` with every node's annotation passed through [[p]], so
   * annotations bound by [[infer]] become their resolved types. Literal nodes
   * are returned as-is.
   */
  def g(ast: Exp): Exp = ast match {
    case EFun(t, vs, body) => EFun(p(t), vs, g(body))
    case EId(t, name) => EId(p(t), name)
    case EInt(_) => ast
    case EStr(_) => ast
    case EAdd(t, a, b) => EAdd(p(t), g(a), g(b))
    case Call(t, fn, args) => Call(p(t), g(fn), gs(args))
    case EBlock(t, vs) => EBlock(p(t), gs(vs))
    case EVar(t, v, d) => EVar(p(t), v, g(d))
    case ERet(t, v) => ERet(p(t), g(v))
  }
}

object hm {
  def n:TVar = TypeSystem.newVar()

  def main(args: Array[String]){

    Console.setOut(new java.io.PrintStream(Console.out, true, "utf-8"))

    val var1 = n
    val var2 = n
    val pairtype = TNew(",", Array(var1, var2))

    val var3 = n

    val env = Map(
      "pair" -> TFun(List(var1,var2), pairtype),
      "true" -> TBool(),
      "if" -> TFun(List(TBool(), var3, var3), var3),
      "zero" -> TFun(List(TInt()), TBool()),
      "pred" -> TFun(List(TInt()), TInt()),
      "mul"-> TFun(List(TInt(), TInt()), TInt()),
      "add"-> TFun(List(var1,var2),var3)

    // Should fail:
    // fn x => (pair(x(3) (x(true)))
    tryexp(env, EFun(n,List("x"),
            Call(n,EId(n,"x"), List(EInt(3))),
            Call(n,EId(n,"x"), List(EId(n,"true")))

    // pair(f(3), f(true))
        Call(n,EId(n,"pair"), List(
		Call(n,EId(n,"f"), List(EInt(3))),
        	Call(n,EId(n,"f"), List(EId(n,"true"))))))

    // { var f = (fn x => x); ((pair (f 4)) (f true)) }
    tryexp(env, EBlock(n,List(
           EFun(n,List("x"), EId(n,"x"))),

      Call(n,EId(n,"pair"), List(
        Call(n,EId(n,"f"), List(EInt(4))),
        Call(n,EId(n,"f"), List(EId(n,"true")))))
    // fn f => f f (fail)
    tryexp(env, EFun(n,

    // let g = fn f => 5 in g g
    tryexp(env, EBlock(n,List(
             EFun(n,List("f"), EInt(5))),
        Call(n,EId(n,"g"), List(EId(n,"g")))

    // example that demonstrates generic and non-generic Vars:
    // fn g => let f = fn x => g in pair (f 3, f true)
fn g => {
  var f = fn x g;
  pair(f 3, f true)
    tryexp(env, EFun(n,List("g"), EBlock(n,List(
           EVar(n,"f", EFun(n,List("x"), EId(n,"g"))),
           Call(n,EId(n,"pair"), List(
             Call(n,EId(n,"f"), List(EInt(3))),
             Call(n,EId(n,"f"), List(EId(n,"true")))

    // Function composition
    // fn f (fn g (fn arg (f g arg)))
                EFun(n,List("arg"), Call(n,EId(n,"g"), List(Call(n,EId(n,"f"), List(EId(n,"arg"))))))))
    tryexp(env, EBlock(n,List()))
    tryexp(env, EBlock(n,List[Exp](
      EVar(n,"a", EInt(1)),
    tryexp(env, EBlock(n,List[Exp](
      EVar(n,"a", EInt(1)),
      EVar(n,"a", EId(n,"true")),
    // if("true")  3 else 4
    tryexp(env, Call(n,EId(n,"if"), List(EId(n,"true"),EInt(4),EInt(5))))
    tryexp(env, Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),EInt(5))))
    tryexp(env, EBlock(n,List[Exp](
      Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),ERet(n,EInt(5)))),

    tryexp(env, EFun(n,List(),EBlock(n,List[Exp](
      Call(n,EId(n,"if"), List(EId(n,"true"),ERet(n,EInt(4)),ERet(n,EInt(5)))),

    tryexp(env, Call(n,EId(n,"mul"),List(EInt(1),EInt(2))))
    tryexp(env, Call(n,EId(n,"mul"),List(EStr("a"),EInt(2))))
    tryexp(env, Call(n,EId(n,"add"),List(EInt(1),EInt(2))))
    tryexp(env, Call(n,EId(n,"add"),List(EStr("a"),EInt(2))))

  def tryexp(env: Map[String, Type], ast: Exp) {
    print(Exp.string(ast) + " : ")
    try {
      val t = TypeSystem.infer(ast, env)
      case ParseError(m) => print(m)
      case TypeError(m) => print(m)