暗黙の型変換付き四則演算計算機のコンパイラ(X86_64用)

doubleとintの暗黙の型変換付き四則演算のX86_64用コンパイラを作りました。

このプログラムでは以下のような処理を行っています。

1. 構文木を演算子順位法で読み込み(read)

2. パターンマッチングで抽象構文木に変換(st2ast)

3. パターンマッチングで暗黙の型変換(implicitConversion)

4. ASTをSCVMのコード列であるListにコンパイル(compile)

5. OS X用のx86_64アセンブラを出力(emit_x86_64)

前回のインタプリタ実行部分をJITする感じで、Mac OS Xのx86_64アセンブラを出力するように修正した形になっています。

scalac t.scala; scala t.t > t.s ; gcc t.s ; ./a.out

コンパイルして、出力結果をファイルに保存し、gccでアセンブル・リンクして実行することができます。

package t
/**
 * Command-line entry point of the calculator compiler.
 *
 * Runs the whole pipeline (read -> st2ast -> implicitConversion -> compile ->
 * emit_x86_64) and prints the generated assembly to standard output. Each
 * intermediate result is printed first on a line starting with "# ".
 */
object t {
  /**
   * Compiles the expression given as the first command-line argument, or the
   * built-in sample expression when no argument is supplied.
   *
   * @param argv command-line arguments; only the first one is used
   */
  def main(argv:Array[String]): Unit = {
    apply(argv.headOption.getOrElse("(2.0+3)/1.5*3000"))
  }

  /**
   * Compiles one expression and prints the intermediate stages and the assembly.
   *
   * @param s source text of the arithmetic expression
   */
  def apply(s:String): Unit = {
    val st = read(s)
    println("# st "+st)
    val ast = st2ast(st)
    println("# ast "+ast)
    val tast = implicitConversion(ast)
    println("# tast "+tast)
    val codes = compile(tast)
    println("# codes "+codes)
    emit_x86_64(codes)
  }
}

/**
 * Operator-precedence reader: turns source text into a syntax tree made of
 * nested tuples, numbers and strings.
 *
 * Binary operators produce `(left, op, right)`, a parenthesised group produces
 * `("(", inner, ")")`, and a postfix parenthesis produces
 * `(target, "(", inner, ")")`.
 */
object read {
  // Every token pattern skips leading whitespace, captures one token in group 1
  // and the unconsumed remainder of the input in group 2.
  // Integer and byte literals accept a leading zero so that "0" is a valid number.
  val ints = """(?s)^[\t\r\n ]*([0-9]+)(.*$)""".r
  val bytes = """(?s)^[\t\r\n ]*([0-9]+b)(.*$)""".r
  val doubles = """(?s)^[\t\r\n ]*([0-9]+\.[0-9]*)(.*$)""".r
  // The trailing empty alternative makes this pattern match any input, so it
  // doubles as the fallback: at end of input (or on an unknown character) the
  // token becomes "".
  val ns = """(?s)^[\t\r\n ]*([a-zA-Z_][a-zA-Z_0-9]*|[\(\)\{\}+/*\-,=;]|)(.*$)""".r

  /**
   * Parses `str` and returns its syntax tree.
   *
   * @param str source text of the expression
   * @return the syntax tree as nested tuples / numbers / strings
   * @throws Exception when a closing token expected by `eat` is not found
   */
  def apply(str:String):Any = {
    var src = str
    var token:Any = ""   // current look-ahead token
    var ptoken:Any = ""  // token consumed by the most recent lex()

    // Advances the look-ahead by one token and returns the token just consumed.
    def lex():Any = {
      ptoken = token
      src match {
        case doubles(a,b) => token = a.toDouble; src = b
        case bytes(a,b) => token = a.substring(0,a.length-1).toByte; src = b
        case ints(a,b) => token = a.toInt; src = b
        case ns(a,b) => token = a; src = b
      }
      ptoken
    }

    // Consumes one token and fails unless it equals `e`.
    def eat(e:Any):Any = {
      if(lex() != e) {
        throw new Exception("syntax error. found unexpected token "+ptoken)
      }
      ptoken
    }

    // Prime the look-ahead with the first token.
    lex()

    // Prefix table: (precedence, kind, closing token), or -1 when `a` is not a prefix.
    def prs(a:Any):Any = a match {
      case "(" => (0, "p",")")
      case _ => -1
    }

    // Infix/postfix table: (precedence, "l") for left-associative binary operators,
    // (precedence, "p", closing token) for a postfix parenthesis, -1 otherwise.
    def ins(a:Any):Any = a match {
      case "+" => (10,"l")
      case "-" => (10,"l")
      case "*" => (20,"l")
      case "/" => (20,"l")
      case "(" => (0,"p",")")
      case _ => -1
    }

    // Parses an expression whose operators all bind tighter than precedence `p`.
    def exp(p:Int):Any = {
      // Parses a primary: either a plain token or a parenthesised group.
      def pr(t:Any):Any = {
        val op = t
        prs(op) match {
          case (np:Int,"p",ep) => val e = exp(np); (op,e,eat(ep))
          case _ => op
        }
      }
      // Extends the left operand `t` for as long as an operator follows.
      def in(t:Any):Any = {
        ins(token) match {
          case (np:Int,"l") if np > p => val op = lex(); in((t, op, exp(np)))
          case (np:Int,"p",ep) => val sp = lex(); val e = exp(np); in((t, sp, e, eat(ep)))
          case _ => t
        }
      }
      // An immediately closing token means an empty expression.
      if(token == ")" || token == "}") "void"
      else in(pr(lex()))
    }
    exp(0)
  }
}

/** Types that can be attached to expression nodes. */
trait TType
/** Double-precision floating point type. */
final case class TFloat() extends TType
/** Integer type. */
final case class TInt() extends TType
/** Placeholder for a type that has not been determined yet. */
final case class TNone() extends TType
/** String type. */
final case class TStr() extends TType
/** Function type with parameter types `prms` and result type `rc`. */
final case class TFun(prms:List[TType],rc:TType) extends TType

/**
 * Base of all expression nodes; every node carries a type `t`.
 *
 * This is a plain abstract class rather than a case class: a case class
 * extending another case class is rejected by the compiler since Scala 2.10.
 */
sealed abstract class E {
  /** The type attached to this node. */
  def t:TType
}
/** Identifier `s`. */
final case class EId(t:TType, s:String) extends E
/** Conversion of expression `a` to type `t`. */
final case class ECast(t:TType, a:E) extends E
/** Integer literal. */
final case class EInt(t:TType, a:scala.Int) extends E
/** Floating point literal. */
final case class EFloat(t:TType, a:Double) extends E
/** Call of `f` with arguments `a`. */
final case class ECall(t:TType, f:E, a:List[E]) extends E

/**
 * Converts the tuple-based syntax tree produced by the reader into an
 * untyped abstract syntax tree.
 */
object st2ast {
  // Function name called for each binary operator token.
  private val operatorNames = Map("+" -> "add", "-" -> "sub", "*" -> "mul", "/" -> "div")

  /**
   * Translates one syntax-tree node.
   *
   * Integers and doubles become literals, `(a, op, b)` becomes a call of the
   * function named after the operator (typed `TNone` for now), and a
   * parenthesised group is replaced by its contents. Any other shape raises
   * a `MatchError`.
   *
   * @param st syntax-tree node
   * @return the corresponding expression
   */
  def apply(st:Any):E = st match {
    case n:Int =>
      EInt(TInt(), n)
    case d:Double =>
      EFloat(TFloat(), d)
    case (lhs, op:String, rhs) if operatorNames.contains(op) =>
      val callee = EId(TNone(), operatorNames(op))
      val operands = List(st2ast(lhs), st2ast(rhs))
      ECall(TNone(), callee, operands)
    case ("(", inner, ")") =>
      st2ast(inner)
  }
}

/**
 * Typing pass: resolves every call against the function table and inserts
 * `ECast` nodes where an implicit conversion makes an argument fit.
 */
object implicitConversion {
  /**
   * Types the expression `e`.
   *
   * @param e untyped (or partially typed) expression
   * @return the expression with call types resolved and casts inserted
   * @throws Exception when no overload of a called function accepts the arguments
   */
  def apply(e:E):E = f(e)

  /** Manual smoke test: prints the typed form of a few sample additions. */
  def main(argv:Array[String]): Unit = {
    val add = EId(TNone(),"add")
    println("*"+f(ECall(TNone(),add,List(EFloat(TFloat(),1),EInt(TInt(),2)))))
    println("*"+f(ECall(TNone(),add,List(EInt(TInt(),1),EInt(TInt(),2)))))
    println("*"+f(ECall(TNone(),add,List(
      ECall(TNone(),add,List(EFloat(TFloat(),1),EInt(TInt(),1))),
      EInt(TInt(),2)))))
  }

  /**
   * Recursive worker behind `apply`.
   *
   * @param e expression to type
   * @return the typed expression
   * @throws Exception when a call cannot be resolved
   */
  def f(e:E):E = e match {
    case ECast(t,a) => ECast(t, f(a))
    case EInt(_,i) => EInt(TInt(), i)
    case EFloat(_,d) => EFloat(TFloat(), d)
    case EId(t,a) => EId(t,a)
    case ECall(_,EId(_,id),xs1) =>
      // Type the arguments first; overload resolution works on typed arguments.
      val xs = xs1.map(f)

      // Checks the arguments against the parameter types `ts`, inserting
      // implicit conversions where needed. None means the overload does not fit.
      def typeCheck(ts:List[TType],es:List[E]):Option[List[E]] = (ts,es) match {
        case (Nil,Nil) => Some(Nil)
        case (t::tRest, arg::argRest) =>
          typeCheck(tRest,argRest).flatMap { checked =>
            val actual = arg match {
              case ECall(TFun(_,r),_,_) => r
              case other => other.t
            }
            if(t == actual) Some(arg::checked)
            else implicitConversions.get(t->actual).map(conv => conv(arg)::checked)
          }
        // Different number of parameters and arguments.
        case _ => None
      }

      // Returns the call built from the first overload that accepts the arguments.
      def fns(funs:List[TFun]):Option[E] = funs match {
        case Nil => None
        case (sig@TFun(params,ret))::rest =>
          typeCheck(params, xs) match {
            case None => fns(rest)
            case Some(args) => Some(ECall(ret, EId(sig,id), args))
          }
      }

      // An unknown function name is reported the same way as a failed overload
      // resolution instead of escaping as a bare NoSuchElementException.
      fns(functions.getOrElse(id, Nil)).getOrElse(throw new Exception("not found method "+e))
    case ECall(_,_,_) =>
      // Only calls through a plain identifier can be resolved.
      throw new Exception("unsupported call target "+e)
  }

  // All four arithmetic functions share the same overloads, tried in this order.
  private val arithmeticOverloads = List(
    TFun(List(TInt(),TInt()),TInt()),
    TFun(List(TStr(),TStr()),TStr()),
    TFun(List(TFloat(),TFloat()),TFloat())
  )

  // Function table: name -> overloads, tried in list order.
  var functions = Map[String,List[TFun]](
    "add"->arithmeticOverloads,
    "sub"->arithmeticOverloads,
    "mul"->arithmeticOverloads,
    "div"->arithmeticOverloads
  )

  // Implicit conversion table, looked up with (parameter type, argument type).
  // NOTE(review): the two TStr entries are keyed on a TStr *argument* yet wrap it
  // in a cast to TStr, unlike the (TFloat, TInt) entry which casts to the
  // parameter type. Confirm the intended orientation before relying on string
  // conversions; simply swapping the keys would let the TStr overload match
  // before the TFloat one.
  var implicitConversions = Map[(TType,TType),(E)=>E] (
    (TInt(),TStr()) -> ((a:E)=>ECast(TStr(),a)),
    (TFloat(),TStr()) -> ((a:E)=>ECast(TStr(),a)),
    (TFloat(),TInt()) -> ((a:E)=>ECast(TFloat(),a))
  )
}

/**
 * Flattens a typed expression into a list of stack-machine instructions.
 */
object compile {

  /**
   * Compiles `e` and appends the print instruction matching its type.
   *
   * @param e typed expression
   * @return the instruction list
   */
  def apply(e:E):List[Any] = {
    val epilogue = List[Any]("print"+typeNames(e.t))
    f(e, epilogue)
  }

  /**
   * One-letter suffix used in instruction names: "i" for integers, "d" for
   * doubles; a function type uses the suffix of its result type.
   *
   * @throws Exception for any other type
   */
  def typeNames(t:TType):String = t match {
    case TInt() => "i"
    case TFloat() => "d"
    case TFun(_,result) => typeNames(result)
    case _ => throw new Exception("compile error unknown type "+t)
  }

  /**
   * Puts the instructions for `e` in front of the already generated list `l`.
   *
   * @param e expression to compile
   * @param l instructions that run after `e`
   * @return the instructions for `e` followed by `l`
   * @throws Exception when `e` is a bare identifier
   */
  def f(e:E,l:List[Any]):List[Any] = e match {
    case EInt(_,n) =>
      ("pushi",n)::l
    case EFloat(_,x) =>
      ("pushd",x)::l
    case ECast(to,inner) =>
      val castOp = "cast"+typeNames(inner.t)+"2"+typeNames(to)
      f(inner, castOp::l)
    case ECall(ret,EId(_,name),args) =>
      // Each argument's code is placed in front of what has been built so far,
      // so the code of later arguments ends up before that of earlier ones.
      def prependArgs(rest:List[E], acc:List[Any]):List[Any] = rest match {
        case Nil => acc
        case arg::more => prependArgs(more, f(arg, acc))
      }
      prependArgs(args, (name+typeNames(ret))::l)
    case EId(_,_) =>
      throw new Exception("compile error")
  }
}

/**
 * Prints Mac OS X x86_64 assembly (AT&T syntax) for a list of stack-machine
 * instructions to standard output.
 *
 * The output consists of the `_printd` and `_printi` helper routines, a
 * `_main` routine containing the translated instructions, and a literal
 * section holding the double constants referenced from `_main`.
 */
object emit_x86_64 {
  /** Writes one piece of assembly text to standard output. */
  def asm(s:String): Unit = {
    println(s)
  }

  /**
   * Emits a complete assembly file for the instruction list `prgs`.
   *
   * @param prgs instructions as produced by the `compile` object
   * @throws Exception when an instruction is not known to `comp`
   */
  def apply(prgs:List[Any]): Unit = {
    _printd()
    _printi()
    enter("_main")
    compile(prgs)
    // printf leaves its result in %eax; clear it so the program exits with status 0.
    asm("\txorl\t%eax, %eax")
    leave()
  }

  /** Emits `_printd`, which prints the double in %xmm0 using printf("%f\n"). */
  def _printd(): Unit = {
    asm("""
	.section	__TEXT,__text,regular,pure_instructions
	.globl	_printd
_printd:
	pushq	%rbp
	movq	%rsp, %rbp
	subq	$16, %rsp
	movsd	%xmm0, -8(%rbp)
	movsd	-8(%rbp), %xmm0
	movb	$1, %al
	leaq	L_.str(%rip), %rcx
	movq	%rcx, %rdi
	callq	_printf
	addq	$16, %rsp
	popq	%rbp
	ret
 	.section	__TEXT,__cstring,cstring_literals
L_.str:
	.asciz	 "%f\n"
"""
     )
  }

  /** Emits `_printi`, which prints the integer in %edi using printf("%d\n"). */
  def _printi(): Unit = {
    asm("""
	.section	__TEXT,__text,regular,pure_instructions
	.globl	_printi
_printi:
	pushq	%rbp
	movq	%rsp, %rbp
	subq	$16, %rsp
	movl	%edi, %eax
	movl	%eax, -4(%rbp)
	movl	-4(%rbp), %eax
	xorb	%cl, %cl
	leaq	L_.str1(%rip), %rdx
	movq	%rdx, %rdi
	movl	%eax, %esi
	movb	%cl, %al
	callq	_printf
	addq	$16, %rsp
	popq	%rbp
	ret
	.section	__TEXT,__cstring,cstring_literals
L_.str1:
	.asciz	 "%d\n"
    """)
  }

  /** Raw IEEE 754 bit pattern of `a`, used to emit the constant as a `.quad`. */
  def d2l(a:Double):Long = java.lang.Double.doubleToLongBits(a)

  // Double constants referenced by the routine being emitted, in label order.
  var doubles = List[Double]()
  // Number of labels handed out for the routine being emitted.
  var doublesCount = 0

  /**
   * Registers the constant `d` and returns the label under which `leave`
   * will define it.
   */
  def doubleL(d:Double):String = {
    val label = "LCPI_"+doublesCount
    doubles = doubles:::List(d)
    doublesCount += 1
    label
  }

  /** Emits the header and prologue of the global routine `fname`. */
  def enter(fname:String): Unit = {
    asm("\t.section\t__TEXT,__text,regular,pure_instructions")
    asm("\t.globl\t"+fname)
    asm(fname+":")
    asm("\tpushq\t%rbp")
    asm("\tmovq\t%rsp, %rbp")
  }

  /**
   * Emits the epilogue of the current routine followed by the literal
   * section defining every constant registered through `doubleL`.
   */
  def leave(): Unit = {
    asm("\tpopq\t%rbp")
    asm("\tret")
    asm("\t.section\t__TEXT,__literal8,8byte_literals")
    asm("\t.align\t3")
    doubles.zipWithIndex.foreach {
      case (d,i) =>
        asm("LCPI_"+i+":")
        asm("\t.quad\t"+d2l(d)+"\n")
    }
    // Reset the list and the counter together: labels are defined above by list
    // index, so the next routine must start handing out labels from LCPI_0 again.
    doubles = List[Double]()
    doublesCount = 0
  }

  /** Translates every instruction of `ls` in order. */
  def compile(ls:List[Any]): Unit = {
    ls.foreach(comp)
  }

  // Double binary operation: top of stack goes to %xmm0, the next value to
  // %xmm1, and the result of `insn %xmm1, %xmm0` is pushed back.
  private def binaryD(insn:String): Unit = {
    comp(("popd","%xmm0"))
    comp(("popd","%xmm1"))
    asm("\t"+insn+" %xmm1, %xmm0")
    comp(("pushd","%xmm0"))
  }

  // Integer binary operation: top of stack goes to %rax, the next value to
  // %rcx, and the result of `insn %ecx, %eax` is pushed back.
  private def binaryI(insn:String): Unit = {
    comp(("popi","%rax"))
    comp(("popi","%rcx"))
    asm("\t"+insn+" %ecx, %eax")
    comp(("pushi","%rax"))
  }

  /**
   * Translates a single instruction.
   *
   * Values live on the machine stack, one 8-byte slot each.
   *
   * @param l either an instruction name or a pair of name and operand
   * @throws Exception when the instruction is unknown
   */
  def comp(l:Any): Unit = {
    l match {
    case ("pushd", d:Double) =>
      asm("\tmovq\t"+doubleL(d)+"(%rip), %rax")
      asm("\tpushq\t%rax")
    case ("pushd", reg:String) =>
      asm("\tmovd "+reg+", %rax")
      asm("\tpushq\t%rax")
    case ("popd", reg:String) =>
      asm("\tpopq %rax")
      asm("\tmovd %rax, "+reg)
    case "addd" => binaryD("addsd")
    case "subd" => binaryD("subsd")
    case "muld" => binaryD("mulsd")
    case "divd" => binaryD("divsd")
    case "printd" =>
      comp(("popd","%xmm0"))
      asm("\tcallq\t_printd")
    case ("pushi", i:Int) =>
      asm("\tmovl\t$"+i+", %eax")
      asm("\tpushq\t%rax")
    case ("pushi", reg:String) =>
      asm("\tpushq\t"+reg)
    case ("popi", reg:String) =>
      asm("\tpopq "+reg)
    case "addi" => binaryI("addl")
    case "subi" => binaryI("subl")
    case "muli" => binaryI("imull")
    case "divi" =>
      comp(("popi","%rax"))
      comp(("popi","%rcx"))
      // idivl divides %edx:%eax, so sign-extend %eax into %edx first.
      asm("\tcltd")
      asm("\tidivl %ecx")
      comp(("pushi","%rax"))
    case "printi" =>
      comp(("popi","%rdi"))
      asm("\tcall\t_printi")
    case "casti2d" =>
      comp(("popi","%rax"))
      asm("\tcvtsi2sd %eax, %xmm0")
      comp(("pushd","%xmm0"))
    case "castd2i" =>
      comp(("popd","%xmm0"))
      asm("\tcvttsd2si\t%xmm0, %eax")
      comp(("pushi","%rax"))
    case other =>
      throw new Exception("emit error unknown instruction "+other)
    }
  }
}