型推論

以下のサイトを参考に

http://dysphoria.net/2009/06/28/hindley-milner-type-inference-in-scala/

Hindley-Milner型推論のPython版をPHPに移植してみました。
Modula2->Perl->Scala->Python->PHP
という流れで移植に移植を重ねられています。
洗練されたり、悪くなったりしてると思います。
今週末はいろいろと、型推論のソースを見てました。
まだよく分かっていませんが、単相だけでなく多相の型推論も視野に入ってきた感じで
嬉しいです。

<?php

/**
 * Base class of every syntax-tree node handled by the inferer.
 *
 * The concrete node classes only store their children; the textual rendering
 * of all node kinds is centralised here.
 */
abstract class Term {

    /**
     * Renders the term in an ML-like surface syntax.
     *
     * Child terms are rendered recursively through string interpolation.
     * A subclass that is not one of the five known node kinds is rendered as
     * its class name, because __toString() must always return a string
     * (falling off the end would return null, which PHP treats as an error).
     *
     * @return string
     */
    function __toString() {
        if ($this instanceof Ident) {
            // Cast so that a numeric name (e.g. Ident(5)) still yields a string.
            return (string) $this->name;
        }
        if ($this instanceof Lambda) {
            return "(fn {$this->v} => {$this->body})";
        }
        if ($this instanceof Apply) {
            return "({$this->fn} {$this->arg})";
        }
        if ($this instanceof Let) {
            return "(let {$this->v} = {$this->defn} in {$this->body})";
        }
        if ($this instanceof Letrec) {
            return "(letrec {$this->v} = {$this->defn} in {$this->body})";
        }
        return get_class($this);
    }

}

/**
 * Lambda abstraction node, rendered as `(fn v => body)`.
 *
 * The properties are declared explicitly: creating them dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class Lambda extends Term {

    /** Name of the bound parameter. */
    public $v;

    /** Term that forms the function body. */
    public $body;

    /**
     * @param mixed $v    name of the bound parameter
     * @param mixed $body body term
     */
    public function __construct($v, $body) {
        $this->v = $v;
        $this->body = $body;
    }

}

/**
 * Factory shorthand: builds a Lambda node without the `new` keyword.
 *
 * @param mixed $v    name of the bound parameter
 * @param mixed $body body term
 * @return Lambda
 */
function Lambda($v, $body) {
    $term = new Lambda($v, $body);

    return $term;
}

/**
 * Identifier node, rendered as its bare name.
 *
 * The property is declared explicitly: creating it dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class Ident extends Term {

    /** The identifier's name. */
    public $name;

    /**
     * @param mixed $name the identifier's name
     */
    public function __construct($name) {
        $this->name = $name;
    }

}

/**
 * Factory shorthand: builds an Ident node without the `new` keyword.
 *
 * @param mixed $name the identifier's name
 * @return Ident
 */
function Ident($name) {
    $term = new Ident($name);

    return $term;
}

/**
 * Function application node, rendered as `(fn arg)`.
 *
 * The properties are declared explicitly: creating them dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class Apply extends Term {

    /** Term in function position. */
    public $fn;

    /** Term in argument position. */
    public $arg;

    /**
     * @param mixed $fn  term being applied
     * @param mixed $arg term it is applied to
     */
    public function __construct($fn, $arg) {
        $this->fn = $fn;
        $this->arg = $arg;
    }

}

/**
 * Factory shorthand: builds an Apply node without the `new` keyword.
 *
 * @param mixed $fn  term being applied
 * @param mixed $arg term it is applied to
 * @return Apply
 */
function Apply($fn, $arg) {
    $term = new Apply($fn, $arg);

    return $term;
}

/**
 * Let-binding node, rendered as `(let v = defn in body)`.
 *
 * The properties are declared explicitly: creating them dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class Let extends Term {

    /** Name being bound. */
    public $v;

    /** Term whose value is bound to the name. */
    public $defn;

    /** Term in which the binding is visible. */
    public $body;

    /**
     * @param mixed $v    name being bound
     * @param mixed $defn defining term
     * @param mixed $body body term
     */
    public function __construct($v, $defn, $body) {
        $this->v = $v;
        $this->defn = $defn;
        $this->body = $body;
    }

}

/**
 * Factory shorthand: builds a Let node without the `new` keyword.
 *
 * @param mixed $v    name being bound
 * @param mixed $defn defining term
 * @param mixed $body body term
 * @return Let
 */
function Let($v, $defn, $body) {
    $term = new Let($v, $defn, $body);

    return $term;
}

/**
 * Recursive let-binding node, rendered as `(letrec v = defn in body)`.
 *
 * The properties are declared explicitly: creating them dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class Letrec extends Term {

    /** Name being bound. */
    public $v;

    /** Term whose value is bound to the name. */
    public $defn;

    /** Term in which the binding is visible. */
    public $body;

    /**
     * @param mixed $v    name being bound
     * @param mixed $defn defining term
     * @param mixed $body body term
     */
    public function __construct($v, $defn, $body) {
        $this->v = $v;
        $this->defn = $defn;
        $this->body = $body;
    }

}

/**
 * Factory shorthand: builds a Letrec node without the `new` keyword.
 *
 * @param mixed $v    name being bound
 * @param mixed $defn defining term
 * @param mixed $body body term
 * @return Letrec
 */
function Letrec($v, $defn, $body) {
    $term = new Letrec($v, $defn, $body);

    return $term;
}

// Raised when two types cannot be unified.
//
// PHP 7 and later ship a built-in TypeError class, and declaring the name
// again unconditionally is a fatal "name already in use" error. The class is
// therefore only declared when the runtime does not already provide it.
// NOTE: the built-in class extends Error rather than Exception, so on PHP 7+
// a `catch (Exception $e)` will not catch it; `catch (TypeError $e)` will.
if (!class_exists('TypeError', false)) {
    class TypeError extends Exception {

    }
}

// Raised when an identifier cannot be resolved.
//
// PHP 7 and later ship a built-in ParseError class, and declaring the name
// again unconditionally is a fatal "name already in use" error. The class is
// therefore only declared when the runtime does not already provide it.
// NOTE: the built-in class extends Error rather than Exception, so on PHP 7+
// a `catch (Exception $e)` will not catch it; `catch (ParseError $e)` will.
if (!class_exists('ParseError', false)) {
    class ParseError extends Exception {

    }
}

/**
 * A type variable standing for an as-yet-unknown type.
 *
 * Every variable gets a unique numeric id at construction time. Its printable
 * name is assigned lazily, on first read of `$name`, so that only variables
 * that are actually printed consume a letter.
 */
class TypeVariable {

    /** Id that the next constructed variable will receive. */
    static $next_variable_id = 0;

    /** Name that the next variable asking for a name will receive. */
    static $next_variable_name = 'a';

    /** Unique id of this variable. */
    public $id;

    /** Type this variable has been bound to, or null while unbound. */
    public $instance = null;

    /**
     * Printable name. Declared (dynamic properties are deprecated since
     * PHP 8.2) but unset in the constructor so that the first read is routed
     * through __get(), which assigns the name lazily.
     */
    public $name;

    function __construct() {
        $this->id = TypeVariable::$next_variable_id;
        TypeVariable::$next_variable_id += 1;
        $this->instance = null;
        // Leave the declared slot uninitialised: reading it triggers __get().
        unset($this->name);
    }

    /**
     * Lazily assigns and returns the printable name on first access.
     *
     * @param string $name property being read
     * @return string|null the variable's name, or null for any other property
     */
    public function __get($name) {
        if ($name === "name") {
            $this->name = TypeVariable::$next_variable_name;
            TypeVariable::$next_variable_name = self::successorName(TypeVariable::$next_variable_name);
            return $this->name;
        }
        return null;
    }

    /**
     * Returns the name following $current: 'a' -> 'b', ..., 'z' -> 'aa',
     * 'az' -> 'ba'. A plain character-code increment would run into
     * punctuation ('{', '|', ...) once 'z' has been used.
     *
     * @param string $current
     * @return string
     */
    private static function successorName($current) {
        $chars = str_split($current);
        for ($i = count($chars) - 1; $i >= 0; $i--) {
            if ($chars[$i] !== 'z') {
                $chars[$i] = chr(ord($chars[$i]) + 1);
                return implode('', $chars);
            }
            // Carry into the next position to the left.
            $chars[$i] = 'a';
        }
        return 'a' . implode('', $chars);
    }

    /**
     * Renders the bound type when there is one, otherwise the variable's name.
     *
     * @return string
     */
    public function __toString() {
        if ($this->instance !== null) {
            return "" . $this->instance;
        }
        return $this->name;
    }

}

/**
 * Factory shorthand: creates a new TypeVariable without the `new` keyword.
 *
 * @return TypeVariable
 */
function TypeVariable() {
    $variable = new TypeVariable();

    return $variable;
}

/**
 * An n-ary type constructor applied to argument types, e.g. `int` (no
 * arguments) or `->` (two arguments).
 *
 * The properties are declared explicitly: creating them dynamically in the
 * constructor is deprecated as of PHP 8.2.
 */
class TypeOperator {

    /** Name of the type constructor. */
    public $name;

    /** Argument types, as a list. */
    public $types;

    /**
     * @param string $name  name of the type constructor
     * @param array  $types argument types
     */
    function __construct($name, $types) {
        $this->name = $name;
        $this->types = $types;
    }

    /**
     * Renders the type: the bare name for zero arguments, infix
     * `(left name right)` for exactly two, and prefix `name arg arg ...`
     * for any other arity.
     *
     * @return string
     */
    public function __toString() {
        switch (count($this->types)) {
            case 0:
                return $this->name;
            case 2:
                return "({$this->types[0]} {$this->name} {$this->types[1]})";
            default:
                return "{$this->name} " . implode(" ", $this->types);
        }
    }

}

/**
 * Factory shorthand: builds a TypeOperator without the `new` keyword.
 *
 * @param string $name  name of the type constructor
 * @param array  $types argument types
 * @return TypeOperator
 */
function TypeOperator($name, $types) {
    $operator = new TypeOperator($name, $types);

    return $operator;
}

/**
 * Builds the function type `from_type -> to_type`, i.e. a binary
 * TypeOperator named "->".
 *
 * @param mixed $from_type argument type
 * @param mixed $to_type   result type
 * @return TypeOperator
 */
function Fun($from_type, $to_type) {
    $operands = array($from_type, $to_type);

    return new TypeOperator("->", $operands);
}

/**
 * Returns the basic integer type.
 *
 * The TypeOperator is created on the first call and the same object is
 * handed out on every later call.
 *
 * @return TypeOperator
 */
function Integer() {
    static $instance = null;

    if ($instance !== null) {
        return $instance;
    }

    $instance = new TypeOperator("int", array());

    return $instance;
}

/**
 * Returns the basic boolean type.
 *
 * The TypeOperator is created on the first call and the same object is
 * handed out on every later call.
 *
 * @return TypeOperator
 */
function Bool() {
    static $instance = null;

    if ($instance !== null) {
        return $instance;
    }

    $instance = new TypeOperator("bool", array());

    return $instance;
}

/**
 * Hindley-Milner type inference over the Term syntax tree.
 *
 * All methods are static; types are represented by TypeVariable and
 * TypeOperator objects, and unification binds variables in place through
 * TypeVariable::$instance.
 */
class Inferer {

    /**
     * Infers the type of a term.
     *
     * @param mixed      $node        term to analyse
     * @param array      $env         map from identifier name to its type
     * @param array|null $non_generic type variables that must not be
     *                                generalised; null means "none"
     * @return mixed the inferred type (TypeVariable or TypeOperator)
     * @throws TypeError  when two types cannot be unified
     * @throws ParseError when an identifier is not defined
     * @throws Exception  when $node is not a known kind of term
     */
    static function analyse($node, $env, $non_generic = null) {
        if ($non_generic === null) {
            $non_generic = array();
        }

        if ($node instanceof Ident) {
            return Inferer::getType($node->name, $env, $non_generic);
        }

        if ($node instanceof Apply) {
            $fun_type = Inferer::analyse($node->fn, $env, $non_generic);
            $arg_type = Inferer::analyse($node->arg, $env, $non_generic);
            $result_type = new TypeVariable();
            // The function must have type "arg_type -> result_type".
            Inferer::unify(Fun($arg_type, $result_type), $fun_type);
            return $result_type;
        }

        if ($node instanceof Lambda) {
            // $env and $non_generic are arrays (copied by value), so these
            // additions are only visible while analysing the body.
            $arg_type = new TypeVariable();
            $env[$node->v] = $arg_type;
            $non_generic[] = $arg_type;
            $result_type = Inferer::analyse($node->body, $env, $non_generic);
            return Fun($arg_type, $result_type);
        }

        if ($node instanceof Let) {
            $env[$node->v] = Inferer::analyse($node->defn, $env, $non_generic);
            return Inferer::analyse($node->body, $env, $non_generic);
        }

        if ($node instanceof Letrec) {
            $new_type = new TypeVariable();
            $env[$node->v] = $new_type;
            // The recursive binding is non-generic only while its own
            // definition is analysed; the body sees it as generic, so the
            // bound name can be used at several different types there.
            $defn_non_generic = $non_generic;
            $defn_non_generic[] = $new_type;
            $defn_type = Inferer::analyse($node->defn, $env, $defn_non_generic);
            Inferer::unify($new_type, $defn_type);
            return Inferer::analyse($node->body, $env, $non_generic);
        }

        $kind = is_object($node) ? get_class($node) : gettype($node);
        throw new Exception("Unhandled syntax node " . $kind);
    }

    /**
     * Looks up the type of an identifier.
     *
     * Names present in $env yield a fresh copy of their type; names made up
     * only of decimal digits are integer literals.
     *
     * @param mixed $name        identifier name
     * @param array $env         map from identifier name to its type
     * @param array $non_generic type variables that must not be generalised
     * @return mixed
     * @throws ParseError when the name is neither bound nor an integer literal
     */
    static function getType($name, $env, $non_generic) {
        if (isset($env[$name])) {
            return Inferer::fresh($env[$name], $non_generic);
        }
        if (Inferer::isIntegerLiteral($name)) {
            return Integer();
        }
        throw new ParseError("Undefined symbol " . $name);
    }

    /**
     * Copies a type, replacing every generic type variable by a new one.
     *
     * The same generic variable is always mapped to the same replacement
     * within one call; non-generic variables are shared with the original.
     *
     * @param mixed $t           type to copy
     * @param array $non_generic type variables that must not be generalised
     * @return mixed the copy (null if $t is neither a TypeVariable nor a
     *               TypeOperator after pruning)
     */
    static function fresh($t, $non_generic) {
        $mappings = array(); // generic variable id => replacement TypeVariable
        $freshrec = null;    // declared first so the closure can refer to itself
        $freshrec = function ($tp) use (&$mappings, $non_generic, &$freshrec) {
            $p = Inferer::prune($tp);
            if ($p instanceof TypeVariable) {
                if (!Inferer::isGeneric($p, $non_generic)) {
                    return $p;
                }
                if (!isset($mappings[$p->id])) {
                    $mappings[$p->id] = new TypeVariable();
                }
                return $mappings[$p->id];
            }
            if ($p instanceof TypeOperator) {
                return new TypeOperator($p->name, array_map($freshrec, $p->types));
            }
            return null;
        };
        return $freshrec($t);
    }

    /**
     * Makes two types equal by binding type variables in place.
     *
     * @param mixed $t1
     * @param mixed $t2
     * @return void
     * @throws TypeError when the types cannot be made equal
     */
    static function unify($t1, $t2) {
        $a = Inferer::prune($t1);
        $b = Inferer::prune($t2);

        if ($a instanceof TypeVariable) {
            if ($a === $b) {
                return; // already the same variable
            }
            if (Inferer::occursInType($a, $b)) {
                throw new TypeError("recursive unification");
            }
            $a->instance = $b;
            return;
        }

        if ($a instanceof TypeOperator && $b instanceof TypeVariable) {
            // Swap so that the variable is on the left-hand side.
            Inferer::unify($b, $a);
            return;
        }

        if ($a instanceof TypeOperator && $b instanceof TypeOperator) {
            if ($a->name !== $b->name || count($a->types) !== count($b->types)) {
                throw new TypeError("Type mismatch: {$a} != {$b}");
            }
            foreach ($a->types as $i => $p) {
                Inferer::unify($p, $b->types[$i]);
            }
            return;
        }

        throw new TypeError("Not unified");
    }

    /**
     * Follows the chain of bound type variables starting at $t and returns
     * what it ends in. Each variable visited is re-pointed directly at that
     * result, shortening the chain for later calls.
     *
     * @param mixed $t
     * @return mixed $t itself unless it is a bound TypeVariable
     */
    static function prune($t) {
        if ($t instanceof TypeVariable && $t->instance !== null) {
            $t->instance = Inferer::prune($t->instance);
            return $t->instance;
        }
        return $t;
    }

    /**
     * A variable is generic when it occurs in none of the non-generic types.
     *
     * @param mixed $v
     * @param array $non_generic
     * @return bool
     */
    static function isGeneric($v, $non_generic) {
        return !Inferer::occursIn($v, $non_generic);
    }

    /**
     * Tells whether the (already pruned) variable $v occurs inside $type2.
     *
     * @param mixed $v
     * @param mixed $type2
     * @return bool
     */
    static function occursInType($v, $type2) {
        $pruned_type2 = Inferer::prune($type2);
        // Identity, not structural equality: two distinct variables are never
        // the same type, and `==` would compare objects property by property.
        if ($pruned_type2 === $v) {
            return true;
        }
        if ($pruned_type2 instanceof TypeOperator) {
            return Inferer::occursIn($v, $pruned_type2->types);
        }
        return false;
    }

    /**
     * Tells whether $t occurs in any of the given types.
     *
     * @param mixed $t
     * @param array $types
     * @return bool
     */
    static function occursIn($t, $types) {
        foreach ($types as $candidate) {
            if (Inferer::occursInType($t, $candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether $name consists solely of decimal digits.
     *
     * `\z` is used instead of `$` so that a name with a trailing newline is
     * not mistaken for a literal.
     *
     * @param mixed $name
     * @return bool
     */
    static function isIntegerLiteral($name) {
        return preg_match('/\A[0-9]+\z/', $name) > 0;
    }

    /**
     * Writes $str followed by a newline to standard output.
     *
     * @param mixed $str
     * @return void
     */
    static function println($str) {
        echo $str . "\n";
    }

    /**
     * Prints "<term> : " followed by either the inferred type or the message
     * of the ParseError / TypeError that inference raised.
     *
     * @param array $env  map from identifier name to its type
     * @param mixed $node term to analyse
     * @return void
     */
    static function tryExp($env, $node) {
        print($node . " : ");
        try {
            $t = Inferer::analyse($node, $env);
            Inferer::println($t);
        } catch (ParseError $e) {
            Inferer::println($e->getMessage());
        } catch (TypeError $e) {
            Inferer::println($e->getMessage());
        }
    }

    /**
     * Demonstration: builds a small initial environment and prints the
     * inference result for a series of sample terms.
     *
     * @return void
     */
    static function main() {

        $var1 = new TypeVariable();
        $var2 = new TypeVariable();
        $pair_type = new TypeOperator("*", array($var1, $var2));

        $var3 = new TypeVariable();

        $env = array(
            "pair" => Fun($var1, Fun($var2, $pair_type)),
            "true" => Bool(),
            "cond" => Fun(Bool(), Fun($var3, Fun($var3, $var3))),
            "zero" => Fun(Integer(), Bool()),
            "pred" => Fun(Integer(), Integer()),
            "times" => Fun(Integer(), Fun(Integer(), Integer()))
        );

        # (pair (f 4)) (f true)
        $pair = Apply(
            Apply(Ident("pair"), Apply(Ident("f"), Ident("4"))),
            Apply(Ident("f"), Ident("true"))
        );

        # factorial
        # letrec factorial = fn n => cond (zero n) 1 (times n (factorial (pred n)))
        # in factorial 5
        $cond_zero_n = Apply(Ident("cond"), Apply(Ident("zero"), Ident("n")));
        $recursive_call = Apply(Ident("factorial"), Apply(Ident("pred"), Ident("n")));
        $product = Apply(Apply(Ident("times"), Ident("n")), $recursive_call);
        Inferer::tryExp($env, Letrec(
            "factorial",
            Lambda("n", Apply(Apply($cond_zero_n, Ident("1")), $product)),
            Apply(Ident("factorial"), Ident("5"))
        ));

        # Should fail:
        # fn x => (pair (x 3)) (x true)
        Inferer::tryExp($env, Lambda("x", Apply(
            Apply(Ident("pair"), Apply(Ident("x"), Ident("3"))),
            Apply(Ident("x"), Ident("true"))
        )));

        # (pair (f 4)) (f true)  -- f is not defined here
        Inferer::tryExp($env, $pair);

        # let f = (fn x => x) in ((pair (f 4)) (f true))
        Inferer::tryExp($env, Let("f", Lambda("x", Ident("x")), $pair));

        # fn f => f f (fail)
        Inferer::tryExp($env, Lambda("f", Apply(Ident("f"), Ident("f"))));

        # let g = fn f => 5 in g g
        Inferer::tryExp($env, Let(
            "g",
            Lambda("f", Ident("5")),
            Apply(Ident("g"), Ident("g"))
        ));

        # example that demonstrates generic and non-generic variables:
        # fn g => let f = fn x => g in (pair (f 3)) (f true)
        Inferer::tryExp($env, Lambda("g", Let(
            "f",
            Lambda("x", Ident("g")),
            Apply(
                Apply(Ident("pair"), Apply(Ident("f"), Ident("3"))),
                Apply(Ident("f"), Ident("true"))
            )
        )));

        # Function composition
        # fn f => fn g => fn arg => g (f arg)
        Inferer::tryExp($env, Lambda("f", Lambda("g", Lambda(
            "arg",
            Apply(Ident("g"), Apply(Ident("f"), Ident("arg")))
        ))));
    }

}

// Script entry point: run the demonstration when the file is executed.
Inferer::main();