sendailockサーバ

というわけで、この間の非同期型echoサーバをいじって、非同期型のロックサーバを作ってみました。

ノンブロッキングIOを使ってロック要求を受け付け、ロック用のマップにリクエストのリストに登録します。
リストが空ならロック取得でき、そうでなければ待ち行列に追加して待ちます。
ロックが解除されると、ロック待ち行列から1つ取り出してロックを渡します。

というようなことをすることで、スレッドは1つでも複数のロックを管理することが出来ます。
ソースのコメントなどは適当ですが、ま、とりあえずこんなところで。

以下ソースです。

package lock;

import scala.actors.Actor;
import scala.actors.Actor._;
import scala.collection.jcl.Conversions._;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.{LinkedHashMap,Map,LinkedHashSet};
import java.util.Map.Entry;
import java.nio.channels.{ SelectionKey, Selector, ServerSocketChannel, SocketChannel }
import java.nio.charset.Charset;
import java.security.MessageDigest;

/**
 * リクエストクラス
 */
case class Request(val remoteAddress:String, val channel:SocketChannel, var lock:String);

object LockActor extends Actor {

	/**
	 * メッセージクラス
	 */
	sealed abstract class Message;
	case class Init(port:Int)			extends Message;
	case class Select()					extends Message;
	case class Accept(key:SelectionKey)	extends Message;
	case class Read(key:SelectionKey)	extends Message;
	case class Write(key:SelectionKey)	extends Message;

	/**
	 * メイン関数
	 */
	def main(args:Array[String]) {
		// ロックアクター開始
		LockActor.start();
		// Initメッセージを送る
		LockActor ! Init(1006);
	}

	/**
	 * アクター実行
	 */
	def act() {
		loop {
			react {
			// Initメッセージ処理
			case Init(port)		=> init(port); LockActor ! Select;
			case Select			=> select();   LockActor ! Select;
			case Accept(key)	=> accept(key);
			case Read(key)		=> read(key);
			case Write(key)		=> write(key);
			}
		}
	}
	// ユーザー
	val user = "hoge";
	// パスワード
	val pass = "pass";

	// md5に変換されたパスワード
	val md5pass = md5(pass);

	// 先頭の文字列
	val head = "LOCK " + user + " " + md5pass + " ";
	
	// バッファのサイズ 1kあれば足りるはず。
	val BUF_SIZE = 1024;

	// セレクター
	val selector:Selector = Selector.open();

	// サーバーチャンネル
	val serverChannel:ServerSocketChannel = ServerSocketChannel.open();

	/**
	 * 初期化処理
	 * @param Int port サーバのポート番号
	 */
	def init(port:Int) {
		// サーバーチャンネルをブロックしない設定
		serverChannel.configureBlocking(false);
		// サーバーチャンネルのソケットを指定ポートでバインド
		serverChannel.socket().bind(new InetSocketAddress(port));
		// セレクターを登録
		serverChannel.register(selector, SelectionKey.OP_ACCEPT);
		// ログ
		println("lock server start port=" + port);
	}

	/**
	 * セレクト処理
	 * 
	 * セレクターに登録してあるキーを取り出してメッセージを送る
	 */
	def select() {
		selector.select();
		selector.selectedKeys().foreach { key =>
			if (key.isAcceptable()) {
				// アクセプト
				LockActor ! Accept(key);
			} else
			if (key.isWritable() && key.isValid()) {
				// 出力
				LockActor ! Write(key);
			} else
			if (key.isReadable()) {
				// 読み込み
				LockActor ! Read(key);
			}
		}
	}
	/**
	 * アクセプト
	 * @param SelectionKey key セレクションキー
	 */
	def accept(key:SelectionKey) {

		// サーバーソケットチャンネルを取得
		val socket:ServerSocketChannel = key.channel().asInstanceOf[ServerSocketChannel];

		// アクセプト
		socket.accept() match {
		// nullなら何もしない
		case null =>
		// channelのとき
		case channel:SocketChannel =>
			// アドレスを取得
			val remoteAddress:String = channel.socket().getRemoteSocketAddress().toString();
			// リクエストオブジェクト作成
			val request = Request(remoteAddress, channel, null);
			println(request + ":connect");
			// ノンブロッキングモードにする
			channel.configureBlocking(false);
			// セレクターに読み込み登録する
			channel.register(selector, SelectionKey.OP_READ, request);
		}
	}

	/**
	 * 読み込み
	 * @param SelectionKey key セレクションキー
	 */
	def read(key:SelectionKey) {

		// リクエストオブジェクト取得
		val request:Request = key.attachment().asInstanceOf[Request];

		// ソケットチャンネル取得
		val channel = key.channel().asInstanceOf[SocketChannel];
		println(request + ":read");
		try {
			// バッファ作成
			val buf:ByteBuffer = ByteBuffer.allocate(BUF_SIZE);

			// 読み込み
			request.channel.read(buf) match {
			// 閉じる
			case -1 => close(request);
			// なにもしない
			case 0 =>
			// 読み込みあり
			case x =>
				// バッファ読み込み
				buf.flip();
				var str = Charset.forName("UTF-8").decode(buf).toString();

				// ヘッダ部取得
				var strHead = str.substring(0, head.length);
				println(request + ":" + str);
				buf.flip();

				// ヘッダチェック
				if (head != strHead) {
					println("ng");
					// 駄目なのでngを返す
					buf.put("ng\r\n".getBytes());
					buf.flip();
					request.channel.write(buf);
					// 切断
					close(request);
				} else {
					println("ok");
					// okなのでロック名保存
					request.lock = str.substring(head.length, str.length - 1);
					
					// 書き込みオッケー状態にするのは、ロック取得できたらにしよう。
					locks.synchronized {
						var map = locks.get(request.lock);
						if (map == null) {
							val newmap = new LinkedHashSet[Request]();
							newmap.add(request);
							locks.put(request.lock, newmap);
							request.channel.register(selector, SelectionKey.OP_WRITE, request);
						} else {
							map.add(request);
						}
					}
				}
			}
		} catch {
		// 接続断などなので接続を閉じる
		case e:java.io.IOException => close(request);
		}
	}

	/**
	 * 出力
	 * @param SelectionKey key キー情報
	 */
	def write(key:SelectionKey) {
		val request:Request = key.attachment().asInstanceOf[Request];
		println(request + ":write");
		try {
			val buf:ByteBuffer = ByteBuffer.allocate(BUF_SIZE);
			buf.put("ok\r\n".getBytes());
			buf.flip();
			request.channel.write(buf);
			request.channel.register(selector, SelectionKey.OP_READ, request);
		} catch {
		case e:java.io.IOException => close(request);
		}
	}

	/**
	 * 接続断
	 * @param Request request リクエストオブジェクト
	 */
	def close(request:Request) {
		request.channel.close();
		println(request + ":close");
		locks.synchronized {
			val map = locks.get(request.lock);
			if (map != null) {
				map.remove(request);
				val itr = map.iterator();
				if (itr.hasNext()) {
					val nextRequest = itr.next();
					nextRequest.channel.register(selector, SelectionKey.OP_WRITE, nextRequest);
				} else {
					locks.remove(request.lock);
				}
			}
		}
	}
	val locks:LRUMap = new LRUMap(1000);
	/**
	 * LRUマップ
	 * 古い順から消えていくマップ
	 * @param Int maxSize 最大サイズ
	 */
	class LRUMap(maxSize:Int) extends LinkedHashMap[String, LinkedHashSet[Request]](16, 0.75f, true) {
		/**
		 * 削除
		 * LRUマップから削除するかどうかの判定をする
		 */
		protected override def removeEldestEntry(eldest:Map.Entry[String, LinkedHashSet[Request]]):Boolean = {
			// maxSizeが0より大きくてsize()がmaxSizeより大きいときに削除
			return maxSize > 0 && size() > maxSize;
		}
	}
	/**
	 * md5文字列計算
	 * @param String str 文字列
	 * @return md5文字列
	 */
	private def md5(str:String):String = {
		try {
			MessageDigest.getInstance("md5").digest(str.getBytes())
			.foldLeft("") {(s:String, i:Byte) =>
				val b = Integer.toHexString(i & 0xff);
				if (b.length() == 1) s + "0" + b;
				else				 s + b;
			}
		} catch {
		case e:Exception => ""
		}
	}
}