リターン

ということで、リターンを考える。
リターンの嫌な所は、関数内の色々な箇所で脱出されること。式が値を持たないのかな?
リターンは関数の型になるけど、式の型にはならないと。
scala> def a():Int = { val a = return 1; a+1}
:8: error: value + is not a member of Nothing
ということで、scalaだと、return式の型はnothing型だと。
で、どう実装仕様って話になるわけだと。
Nothing型と他の型なら、他の型になるってかんじにすればいいようだ。ということだなたぶん。
で、returnの型は取っておいて、最後にunifyをオッケーなんじゃないかと思われる。
カリー化じゃまなんだけど。
カリー化は捨てよう。
Lambdaは素晴らしいけど、Letと一緒で、C言語風にはあわないので、引数はlistでいくつでも持てるように拡張しよう。
とりあえず、Listで拡張してしまったのだけど、タイプコンストラクタを後で使うように修正する予定にしよう。
ってことで、リターンコードをということを後でまとめて、分かりやすく書く。

とにかく疲れたのだけど、進めてよかったー。
課題はTyConでリストとか、タプルを用意できるといいということとか、ソースをもっと奇麗に書くといいと。

リターンはとりあえず、値が入ってない状態で他の式とのunifyで型が決まるようにしました。なので、Nothingではないかもしれませんけどまぁ、いいってかんじにしました。
とにかく、それっぽく動く物が出来たので今日はよしとしました。

ということで、以下ソースになります。

package hm

sealed trait SyntaxNode
case class Lambda(vs: List[String], body: SyntaxNode) extends SyntaxNode
case class Ident(name: String) extends SyntaxNode
case class Apply(fn: SyntaxNode, args: List[SyntaxNode]) extends SyntaxNode
case class Block(vs:List[SyntaxNode]) extends SyntaxNode
case class Var(v:String, defn:SyntaxNode) extends SyntaxNode
case class Return(v:SyntaxNode) extends SyntaxNode

object SyntaxNode {
  def string(ast: SyntaxNode): String = {
    ast match {
    case Ident(i) => i
    case Lambda(v,b) => "(fn ("+v./:(""){case (a,b)=>a +(if(a=="")"" else ",") + b}+") ⇒ "+string(b)+")"
    case Apply(f,a) => string(f)+"("+a./:(""){case (a,b)=>a+(if(a=="")"" else ",")+string(b)}+")"
    case Var(v,d) => "var "+v+" = "+string(d)
    case Block(xs) => "{" + xs./:(""){case (a, b)=>a + (if(a == "") "" else ";") + string(b)} + "}"
    case Return(v) => "return "+string(v)+";"
    }
  }
}


sealed trait Type
case class TyVar(id: Int, var instance: Option[Type]) extends Type {
  lazy val name:String = TypeSystem.nextUniqueName()
}
case class TyCon(name: String, args: Seq[Type]) extends Type
case class TypeError(msg: String) extends Exception(msg)
case class ParseError(msg: String) extends Exception(msg)


object TypeSystem {

  def Function(from:List[Type], to: Type):Type = {
    TyCon("→", from:::List(to))
  }
  val Integer = TyCon("int", List())
  val Bool = TyCon("bool", List())
  val Void = TyCon("Void", List())
  def Nothing:Type = newVar()

  var _nextVarName = 'α';
  def nextUniqueName():String = {
    val result = _nextVarName
    _nextVarName = (_nextVarName.toInt + 1).toChar
    result.toString
  }

  var _nextVarId = 0
  def newVar():TyVar = {
    val result = _nextVarId
    _nextVarId += 1
    TyVar(result, None)
  }

  def string(t:Type):String = {
    t match {
      case TyVar(_, Some(i)) => string(i)
      case v:TyVar => v.name
      case TyCon(name, args) =>
        if (args.length == 0) {
          name
        } else if (args.length == 2) {
          "(" + string(args(0)) + " " + name + " " + string(args(1)) + ")"
        } else {
          args.mkString(name + " ", " ", "")
        }
    }
  }


  def infer(ast:SyntaxNode, env:Map[String, Type]):Type = {
    infer(ast, env, Set())
  }
  var returns = Set[Type]()

  def infer(ast:SyntaxNode, env:Map[String, Type], nongen:Set[TyVar]):Type = {
    ast match {
      case Ident(name) => gettype(name, env, nongen)
      case Apply(fn, args) =>// 関数呼び出し
        val funtype = infer(fn, env, nongen)
        val argtypes = args./:(List[Type]()){case (ls,arg)=> infer(arg, env, nongen)::ls}.reverse
        val resulttype = newVar()
        unify(Function(argtypes, resulttype), funtype)
        resulttype
      case Lambda(args, body) =>// 関数
        val as:Map[String,TyVar] = args./:(Map[String,TyVar]()){case (a,b)=> a + (b -> newVar())}
        val argtypes = as./:(List[Type]()){case (a,(b,c))=> c::a}.reverse
        val aslist = as.toList
        val nongen2 = aslist./:(nongen:Set[TyVar]) {
          case (n, b) => b match { case (b,c:TyVar)=> n + c}
        }
        returns = Set[Type]()
        val resulttype = infer(body, env ++ as, nongen2)
        //println("returns="+returns)
        returns.foreach { case (a) => unify(a, resulttype) }
        Function(argtypes, resulttype)
      case Let(v, defn, body) =>
        val defntype = infer(defn, env, nongen)
        val newenv = env + (v -> defntype)
        infer(body, newenv, nongen)
      case Var(v, defn) => infer(defn, env, nongen)
      case Block(xs) =>
        var benv = Set[String]()
        def binfer(xs:List[SyntaxNode], env:Map[String, Type], nongen:Set[TyVar]):Type = {
          xs match {
            case List() => Void
            case List(x) => infer(x, env, nongen)
            case x::xs =>
              val t = infer(x, env, nongen)
              x match {
                case Var(v, defn) if(benv.contains(v)) => throw new TypeError("Already defined "+v)
                case Var(v, defn) => benv = benv + v; binfer(xs, env + (v -> t), nongen)
                case v => binfer(xs, env, nongen)
              }
          }
        }
        binfer(xs, env, nongen)
      case Return(v) => returns = returns + infer(v, env, nongen); Nothing
    }
  }

  def gettype(name:String, env:Map[String, Type], nongen:Set[TyVar]):Type = {
    if (env.contains(name)) {
      fresh(env(name), nongen)
    } else if (isIntegerLiteral(name)) {
      Integer
    } else {
      throw new ParseError("Undefined symbol " + name)
    }
  }

  def fresh(t:Type, nongen:Set[TyVar]):Type = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[TyVar, TyVar]
    def freshrec(tp:Type):Type = {
      prune(tp) match {
        case v:TyVar =>
          if (isgeneric(v, nongen)) {
            mappings.getOrElseUpdate(v, newVar)
          } else {
            v
          }
        case TyCon(name, args) => TyCon(name, args.map(freshrec(_)))
      }
    }
    freshrec(t)
  }



  def unify(t1:Type, t2:Type) {
    val type1 = prune(t1)
    val type2 = prune(t2)
    (type1, type2) match {
      case (a:TyVar, b) if(a == b) =>
      case (a:TyVar, b) =>
        if (occursintype(a, b)) {
          throw new TypeError("recursive unification")
        }
        a.instance = Some(b)
      case (a:TyCon, b:TyVar) => unify(b, a)
      case (a:TyCon, b:TyCon) =>
        if (a.name != b.name || a.args.length != b.args.length) {
          throw new TypeError("Type mismatch: "+string(a)+"≠"+string(b))
        }
        for (i <- 0 until a.args.length) {
          unify(a.args(i), b.args(i))
        }
    }
  }


  /**
   * tが表す型を返します。
   * ただしTyVar(Some(n))だった場合は、Some内の値を返します。
   * また、Someがネストしていた場合はネストを取り除きます。
   */
  def prune(t:Type):Type = {
    t match {
      case v@TyVar(_, Some(i)) =>
        var inst = prune(i)
        v.instance = Some(inst)
        inst
      case _ => t
    }
  }

  /**
   * Note: must be called with v 'pre-pruned'
   */
  def isgeneric(v: TyVar, nongen: Set[TyVar]):Boolean = {
   // !occursin(v, nongen)
    !nongen.exists{t => occursintype(v, t)}
  }

  /**
   * Note: must be called with v 'pre-pruned'
   */
  def occursintype(v: TyVar, type2: Type): Boolean = {
    prune(type2) match {
      case `v` => true
      case TyCon(name, args) => args.exists{t => occursintype(v, t)}
      case _ => false
    }
  }
/*
  def occursin(t: TyVar, list: Iterable[Type]):Boolean = {
    list.exists{t2 => occursintype(t, t2)}
  }
*/
  def isIntegerLiteral(name: String):Boolean = {
    val checkDigits = "^(\\d+)$".r
    checkDigits.findFirstIn(name).isDefined
  }
}

object hm {

  def main(args: Array[String]){
    Console.setOut(new java.io.PrintStream(Console.out, true, "utf-8"))

    val var1 = TypeSystem.newVar()
    val var2 = TypeSystem.newVar()
    val pairtype = TyCon("×", Array(var1, var2))

    val var3 = TypeSystem.newVar()

    val env = Map(
      "pair" -> TypeSystem.Function(List(var1,var2), pairtype),
      "true" -> TypeSystem.Bool,
      "if" -> TypeSystem.Function(List(TypeSystem.Bool, var3, var3), var3),
      "zero" -> TypeSystem.Function(List(TypeSystem.Integer), TypeSystem.Bool),
      "pred" -> TypeSystem.Function(List(TypeSystem.Integer), TypeSystem.Integer),
      "mul"-> TypeSystem.Function(List(TypeSystem.Integer, TypeSystem.Integer), TypeSystem.Integer)
    )


    val pair = Apply(Ident("pair"), List(Apply(Ident("f"), List(Ident("4"))), Apply(Ident("f"), List(Ident("true")))))
    
    // Should fail:
    // fn x => (pair(x(3) (x(true)))
    tryexp(env, Lambda(List("x"),
          Apply(Ident("pair"),List(
            Apply(Ident("x"), List(Ident("3"))),
            Apply(Ident("x"), List(Ident("true")))))))

    // pair(f(3), f(true))
    tryexp(env,
        Apply(Ident("pair"), List(
		Apply(Ident("f"), List(Ident("4"))), 
        	Apply(Ident("f"), List(Ident("true"))))))

    // letrec f = (fn x => x) in ((pair (f 4)) (f true))
    // tryexp(env, Let("f", Lambda("x", Ident("x")), pair))

    // fn f => f f (fail)
    tryexp(env, Lambda(List("f"), Apply(Ident("f"), List(Ident("f")))))

    // let g = fn f => 5 in g g
    tryexp(env, Block(List(
        Var("g", Lambda(List("f"), Ident("5"))),
        Apply(Ident("g"), List(Ident("g")))
    )))

    // example that demonstrates generic and non-generic Vars:
    // fn g => let f = fn x => g in pair (f 3, f true)
    /*
fn g => {
  var f = fn x g;
  pair(f 3, f true)
}
*/
    tryexp(env, Lambda(List("g"), Block(List(
           Var("f", Lambda(List("x"), Ident("g"))),
           Apply(Ident("pair"), List(
             Apply(Ident("f"), List(Ident("3"))),
             Apply(Ident("f"), List(Ident("true")))
           ))
    ))))

    // Function composition
    // fn f (fn g (fn arg (f g arg)))
    tryexp(env,
      Lambda(List("f"), Lambda(List("g"), Lambda(List("arg"), Apply(Ident("g"), List(Apply(Ident("f"), List(Ident("arg"))))))))
    )
    println(SyntaxNode.string(Block(List(Ident("f"),Ident("d")))))
    tryexp(env, Block(List()))
    tryexp(env, Block(List(
      Var("a", Ident("1")),
      Ident("a")
    )))
    tryexp(env, Block(List(
      Var("a", Ident("1")),
      Var("a", Ident("true")),
      Ident("a")
    )))
    // if("true")  3 else 4
    tryexp(env, Apply(Ident("if"), List(Ident("true"),Ident("4"),Ident("5"))))
    tryexp(env, Apply(Ident("if"), List(Ident("true"),Return(Ident("4")),Ident("5"))))
    tryexp(env, Block(List(
      Apply(Ident("if"), List(Ident("true"),Return(Ident("4")),Return(Ident("5")))),
      Ident("4")
    )))

    tryexp(env, Lambda(List(),Block(List(
      Apply(Ident("if"), List(Ident("true"),Return(Ident("4")),Return(Ident("5")))),
      Ident("true")
    ))))

  }

  def tryexp(env: Map[String, Type], ast: SyntaxNode) {
    print(SyntaxNode.string(ast) + " : ")
    try {
      val t = TypeSystem.infer(ast, env)
      print(TypeSystem.string(t))
    }catch{
      case ParseError(m) => print(m)
      case TypeError(m) => print(m)
    }
    println
  }
}