Hatena::ブログ(Diary)

CLOVER

2013-09-21

Scalaで書く、高信頼型(Reliable)UDP Echoクライアント/サーバ

これまで、Clojureを中心にUDPを使ったプログラミングをしていましたが、今回はちょっとScalaのみにします。

お題は、高信頼型(Reliable)UDP

UDPは、データグラムが正しい順序で到着することを保証しませんし、パケットロスも検出できません。これを克服するために、

  • すべてのデータグラムにシーケンス番号を付ける
  • 送信側は、(受信側からの)確認にタイムアウトを設け、確認が時間以内に来なかったらリクエストを再送する

ということで信頼性を付与します。

タイムアウトと再送については、単に一定のタイムアウト時間を使うやり方ではなく、ネットワークの性質や状態、負荷などを考慮する必要があります。ここで、TCPに習い

  • パケットが送信者から受信者に到着して再び戻るまでの往復時間(Round-Trip Time:ラウンドトリップタイム)の現時点の推計値を統計的に求め、維持する。この推計値を元に、パケットをネットワークに送り出すものの、それが速すぎてネットワークを独占的に充満しないようにする
  • 次のリトライまでの待ち時間を計算する時、「指数的バックオフ」を使うので、エラーが繰り替えされる度に次の再送までの時間が指数的に伸びていく。すると、ネットワーク上のパケット数がだんだん減っていくので、次の再送がうまくいく可能性が高くなる

ということを行います。ここで、指数的バックオフというのは、例えば3度目の再送までの時間は、基準値nの3倍ではなく、2の3乗倍とするようなことです。TCPの輻輳回避のテクニックと。

とまあ、たいそうなことを書いていますがこちらの書籍の受け売りです。

Javaネットワークプログラミングの真髄

Javaネットワークプログラミングの真髄



こちらの書籍に、ReliableなUDP Echoサーバの例があったので、それをScalaで書き直しました。本来はClojureでやった方が…なのですが、今のClojure力だと実装とデバッグが困難になる気がするので、ここはいったんScalaで。

ちなみに、書籍のサンプルには明らかな誤りがあって、そちらはサポートページを見て修正しました。

Javaネットワークプログラミングの真髄--follow-up
http://homepage1.nifty.com/algafield/jnet/

そして、書籍には載っていなかったReliableなクライアントも追加しています。

今回のプログラムで最も重要なのは、ReliableDatagramSocketというクラスで、以下の機能を実装したものです。

  • 送信されるパケットにユニークなシーケンス番号を付け、受信したパケットからそのシーケンスを取り出す。その操作は、呼び出し元には見えないところで行われる
  • 入力からシーケンス番号を取り出すメソッドと、それを出力にセットするメソッドを、サーバがリプライを用意する時に使う
  • クライアントが使う、sendReceiveメソッドを作成する。これは、送信用のデータグラムと受信用のデータグラムを引数に取り、送信のデータグラムを送り出してサーバのリプライを待つ。その後、指定時間までにレスポンスがなければ、間隔を調整しながらタイムアウトと再送を最大リトライカウントまで繰り返す。最大カウントを超えると、SocketTimeoutExceptionを投げる

タイムアウト時の例外は、SocketTimeoutExceptionにしました。さすがに、JDK 1.4より前はもういいでしょう…。

では、書いていきます。ほぼ写経なので、Scalaっぽくないところはご愛嬌。まずはimport文。

import scala.annotation.tailrec
import scala.math.{abs, max, min}
import scala.util.{Failure, Success, Try}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, IOException}
import java.net.{DatagramPacket, DatagramSocket, InetAddress, InetSocketAddress, SocketAddress, SocketTimeoutException}
import java.nio.charset.StandardCharsets
import java.util.Date

import Reliabilities._

Reliabilitiesというのは、これから載せる定数とかを定義したオブジェクトです。

で、その定数とかを定義したReliabilitiesオブジェクト。

// このプログラムで、時間の単位はすべて「秒」

object Reliabilities {
  // タイムアウトの最小・最大値
  val MIN_RETRANSMIT_TIMEOUT: Int = 1
  val MAX_RETRANSMIT_TIMEOUT: Int = 64
  // ひとつのデータグラムの最大再送回数、3か4くらい
  val MAX_RETRANSMISSIONS: Int = 4

  // 通信先ポート
  val PORT: Int = 50000

  def log(msg: Any, msgs: Any*): Unit =
    println(s"[${new Date}] ${(msg :: msgs.toList).mkString(" ")}")

  // AutoCloseableをfor式で使えるようにするための、Implicit Class
  implicit class AutoCloseableWrapper[A <: AutoCloseable](val underlying: A) extends AnyVal {
    def foreach(fun: A => Unit): Unit =
      try {
        fun(underlying)
      } finally {
        underlying.close()
      }
  }
}

コメントにもありますが、今回登場するプログラムに使われる時間の単位は、全て「秒」です。

続いて、ラウンドトリップタイムを計算するクラス。

class RoundTripTimer {
  // 最も最近のラウンドトリップタイム(RTT)
  private var roundTripTime: Float = 0.0F
  // 平滑化したRTT
  private var smoothedTripTime: Float = 0.0F
  // 平滑化した標準偏差
  private var deviation: Float = 0.75F
  // 再送カウント:0, 1, 2
  private var retransmissions: Short = 0
  // 現在の再送タイムアウト
  var currentTimeout = minmax(calculateRetransmitTimeout)

  // 再送タイムアウトを返す
  private def calculateRetransmitTimeout: Int =
    (smoothedTripTime + 4.0 * deviation).toInt

  // 上限のある再送タイムアウトを返す
  private def minmax(rto: Float): Float =
    min(max(rto, MIN_RETRANSMIT_TIMEOUT), MAX_RETRANSMIT_TIMEOUT)

  // 新たなパケットを送信する度に、現在の再送カウントを初期化する
  def newPacket(): Unit =
    retransmissions = 0

  /**
   * 成功した受信の直後に呼ばれ、ラウンドトリップタイムを計算し、
   * 次に平滑化したラウンドトリップとその分散(偏差)を計算する
   **/
  def stoppedAt(ms: Long): Unit = {
    // このパケットのラウンドトリップを計算する
    roundTripTime = ms/1000

    // ラウンドリップタイムの推計値とその平均偏差を更新する
    val delta = roundTripTime - smoothedTripTime
    smoothedTripTime += (delta / 8.0).toFloat
    deviation += ((abs(delta) - deviation) / 4.0).toFloat

    // 現在のタイムアウトを再計算する
    currentTimeout= minmax(calculateRetransmitTimeout)
  }

  /**
   * タイムアウトが生じた後に呼ばれる。ギブアップすべき時間なら true を返却し、
   * 再送できるなら false を返却する
   **/
  def isTimeout(): Boolean = {
    currentTimeout *= 2  // 次の再送タイムアウト
    retransmissions = (retransmissions + 1).toShort
    retransmissions > MAX_RETRANSMISSIONS
  }
}

この後登場するReliableDatagramPacket#sendReceiveメソッドを使ったパケット送信時に、このクラスのnewPacketメソッドを呼び出し、リトライ回数をリセットします。使用元でタイムアウトを検出した時は、isTimeoutメソッドを呼び出しリトライの上限を確認すると共に、次回の再送タイムアウトを伸ばしていきます。正常に通信ができた場合は、stoppedAtメソッドを呼び出しラウンドトリップタイムの推計値を求め、その平均偏差を更新します。

そして、ReliableDatagramSocketクラス。

class ReliableDatagramSocket(localAddr: SocketAddress) extends DatagramSocket(localAddr) {
  private var roundTripTimer: RoundTripTimer = new RoundTripTimer
  private var reinit: Boolean = false
  var sendSequenceNo: Long = 0L  // 送信の順序番号
  var recvSequenceNo: Long = 0L  // 受信の順序番号

  init()

  def this(port: Int) = this(new InetSocketAddress(port))
  def this(port: Int, localAddr: InetAddress) = this(new InetSocketAddress(localAddr, port))
  def this() = this(null)

  // 初期化
  private def init(): Unit =
    roundTripTimer = new RoundTripTimer

  // コネクトした後、接続用の統計を(再)初期化する
  override def connect(dest: InetAddress, port: Int): Unit = {
    super.connect(dest, port)
    init()
  }

  // コネクトした後、接続用の統計を(再)初期化する
  override def connect(dest: SocketAddress): Unit = {
    super.connect(dest)
    init()
  }

  @throws(classOf[IOException])
  def sendReceive(sendPacket: DatagramPacket, recvPacket: DatagramPacket): Unit = synchronized {
    // タイムアウト後に再初期化する
    if (reinit) {
      init()
      reinit = false
    }

    roundTripTimer.newPacket()

    val start = System.currentTimeMillis
    val sequenceNumber = sendSequenceNo

    // 最後のタイムアウト、または予期しない例外が起きるまで繰り返し
    // リトライ中は、同じsequenceNumberを使用し続ける
    Iterator.continually {
      Try {
        sendSequenceNo = sequenceNumber
        send(sendPacket)  // 例外を投げても良い

        val timeout = (roundTripTimer.currentTimeout * 1000.0 + 0.5).toInt
        val soTimeoutStart = System.currentTimeMillis

        @tailrec
        def receiveRetries(): Long = {
          // ソケットのタイムアウト値を、すでに経過した時間で調整する
          val soTimeout = (timeout - (System.currentTimeMillis.toInt - soTimeoutStart)).toInt
          setSoTimeout(soTimeout)
          receive(recvPacket)
          recvSequenceNo match {
            case `sequenceNumber` => recvSequenceNo // シーケンスが一致していれば、ループをストップ
            case _ => receiveRetries()
          }
        }

        receiveRetries()
      }
    }.takeWhile {
      case Success(_) => false  // シーケンスが一致していれば、ループをストップ
      case Failure(e: SocketTimeoutException) =>
        // タイムアウトで、リトライするかどうか
        if (roundTripTimer.isTimeout()) {
          reinit = true
          throw e
        } else {
          // リトライする
          true
        }
      case Failure(e) => throw e
    }.foreach { retry => } // シーケンスの不一致、またはタイムアウトのためリトライ

    // 正しいリプライを得た
    // タイマーを停止し、新たなRTTの値を計算する
    val ms = System.currentTimeMillis - start
    roundTripTimer.stoppedAt(ms)
  }

  // 順序番号を処理する
  @throws(classOf[IOException])
  override def receive(packet: DatagramPacket): Unit = {
    super.receive(packet)

    // 順序番号を読み、それをパケットから削除する
    val bais = new ByteArrayInputStream(packet.getData,
                                        packet.getOffset,
                                        packet.getLength)

    val dis = new DataInputStream(bais)
    recvSequenceNo = dis.readLong()
    val buffer = Array.ofDim[Byte](dis.available)
    dis.read(buffer)
    packet.setData(buffer, 0, buffer.size)
  }

  // 順序番号を処理する
  @throws(classOf[IOException])
  override def send(packet: DatagramPacket): Unit = {
    val baos = new ByteArrayOutputStream
    val dos = new DataOutputStream(baos)

    // 順序番号を書き出し、次にユーザデータを書き出す
    dos.writeLong(sendSequenceNo)
    sendSequenceNo += 1
    dos.write(packet.getData, packet.getOffset, packet.getLength)
    dos.flush()

    // この新しいデータで新たなパケットをコンストラクトし、送信する
    val data = baos.toByteArray
    val newPacket = new DatagramPacket(data, data.size, packet.getSocketAddress)
    super.send(newPacket)
  }
}

このクラスは、先ほど作成したRoundTripTimerクラスのインスタンスを保持します。継承元のDatagramSocketクラスのいくつかのメソッドをオーバーライドしていて、receiveメソッドでは受け取ったレスポンスからシーケンス番号を先に読み出し、呼び出し元が使うDatagramPacketのインスタンスからはシーケンス番号を取り除いています。sendメソッドでは、送信データの前にシーケンス番号を挿入し、DatagramPacketを新しく構築しています。

独自のメソッドがsendReceiveで、こちらはクライアントが使用することを意図しています。最初にRoundTripTimerクラスのnewPacketメソッドを呼び出し、リトライ回数をリセットします。

続いて、DatagramPacketを送信。

あとは、タイムアウト値を調整しつつ、受信したデータのシーケンス番号がおかしかったら再度受信待ち、タイムアウトしたらタイムアウト値を調整してリトライ、もしくはリトライ回数をオーバーしていれば諦めるという感じになっています。

正常なシーケンス番号を受信できた場合は、ループをストップしてRoundTripTimerクラスのstoppedAtメソッドを呼び出しラウンドトリップタイムの推計値を求め、その平均偏差を更新します。

ちなみに、元の書籍が間違っていたのは、ここのsendReceiveメソッドが

      // ソケットのタイムアウトをすでに経過した時間で調整する 
 
        int soTimeout = timeout.(int)
  (System.currentTimeMillis().soTimeoutStart);

みたいな感じで書かれていたのですが、正しくは

        // ソケットのタイムアウトをすでに経過した時間で調整する 
 
        int soTimeout = timeout-(int)
  (System.currentTimeMillis()-soTimeoutStart);

です。

そして、このReliableDatagramSocketを使用するEchoサーバ。

object ReliableEchoServer {
  def main(args: Array[String]): Unit = {
    val buffer = Array.ofDim[Byte](1024)
    val recvPacket = new DatagramPacket(buffer, buffer.size)

    for (socket <- new ReliableDatagramSocket(PORT)) {
      log("Reliable Single Thread Scala UDP Server", socket.getLocalSocketAddress, "Startup.")

      Iterator.from(1).foreach { i =>
        // 受信パケットのバッファを、最大にリセットする
        recvPacket.setData(buffer, 0, buffer.size)
        socket.receive(recvPacket)

        // リプライは、リクエストと同じ順序番号とする
        val seqNo = socket.recvSequenceNo
        socket.sendSequenceNo = seqNo

        // レスポンスとして、リクエストをエコーバックする
        socket.send(recvPacket)

        log(s"Receive & Send SeqNo[${socket.recvSequenceNo}]")
      }
    }
  }
}

これまでのクラスに比べると、だいぶ小さいですね。なお、シングルスレッドで動くことを前提にしています。ここでのポイントは、ReliableDatagramSocket#receiveでデータを受信して書き戻す際、つまりsendメソッドを呼ぶ前に送信用のシーケンス番号を設定していることですね。

最後は、Echoクライアント側。

object ReliableEchoClient {
  def main(args: Array[String]): Unit = {
    // 送信用バッファ
    val sendBuffer = Array.ofDim[Byte](1024)
    // 受信用バッファ
    val recvBuffer = Array.ofDim[Byte](1024)

    val sendPacket = new DatagramPacket(sendBuffer, sendBuffer.size)
    val recvPacket = new DatagramPacket(recvBuffer, recvBuffer.size)
    val address = new InetSocketAddress(PORT)
    sendPacket.setSocketAddress(address)
    recvPacket.setSocketAddress(address)

    for (socket <- new ReliableDatagramSocket) {
      log(s"Reliable Scala UDP Client", address, "Startup.")

      Iterator
        .continually(readLine())
        .takeWhile(word => word != null && word != "exit")
        .foreach { word =>
          // 送信用、受信用パケットをそれぞれ初期化
          sendPacket.setData(sendBuffer, 0, sendBuffer.size)
          recvPacket.setData(recvBuffer, 0, recvBuffer.size)

          val wordBinary = word.getBytes(StandardCharsets.UTF_8)
          System.arraycopy(wordBinary,
                           0,
                           sendPacket.getData,
                           0,
                           wordBinary.size)

          sendPacket.setData(sendPacket.getData, 0, wordBinary.size)

          // 送信 & 受信
          socket.sendReceive(sendPacket, recvPacket)

          val seqNo = socket.recvSequenceNo

          log(s"Receive Sequence No[$seqNo]")
          log(s"Received =>", new String(recvPacket.getData,
                                         recvPacket.getOffset,
                                         recvPacket.getLength,
                                         StandardCharsets.UTF_8))
        }
    }
  }
}

今までと同じように、コンソールから入力されたサーバへ送信するプログラムですが、データの送受信はReliableDatagramSocket#sendReceiveで一括して行っています。

動かしてみると、こんな感じですね。まずはサーバ起動。

$ scala ReliableEchoServer
[Sat Sep 21 23:29:24 JST 2013] Reliable Single Thread Scala UDP Server 0.0.0.0/0.0.0.0:50000 Startup.

クライアント起動。

$ scala ReliableEchoClient
[Sat Sep 21 23:29:47 JST 2013] Reliable Scala UDP Client 0.0.0.0/0.0.0.0:50000 Startup.

あとは、適当に文字列を入力していればサーバが応答してくれます。

Hello World
[Sat Sep 21 23:30:22 JST 2013] Receive Sequence No[0]
[Sat Sep 21 23:30:22 JST 2013] Received => Hello World
こんにちは、世界
[Sat Sep 21 23:30:26 JST 2013] Receive Sequence No[1]
[Sat Sep 21 23:30:26 JST 2013] Received => こんにちは、世界
Reliable UDP Echo.
[Sat Sep 21 23:30:35 JST 2013] Receive Sequence No[2]
[Sat Sep 21 23:30:35 JST 2013] Received => Reliable UDP Echo.
exit

サーバ側には、こんなログが。

[Sat Sep 21 23:30:22 JST 2013] Receive & Send SeqNo[0]
[Sat Sep 21 23:30:26 JST 2013] Receive & Send SeqNo[1]
[Sat Sep 21 23:30:35 JST 2013] Receive & Send SeqNo[2]

で、このプログラムですが、ReliableDatagramSocketがシーケンス番号を持ってしまっているので、複数の接続先には対応できません。あくまで、通信してくるクライアントはひとつの想定です。本来は、接続先ごとにシーケンス番号を管理すべきだと書籍にも書いてありましたしね。サーバプログラムがシングルスレッドになっているのは、これが理由です。

とはいえ、TCPソケットの様に、接続がSocketクラスのインスタンスみたいに表せるわけでもないので、そこは工夫が必要なところですよね。あと、ホントはパケット分割も考慮しないとですよね。

そのうち、JGroupsのソースも追ってみようかな…UDPとかは少しみましたけど、このあたりの信頼性担保のところは、まだ見れていません。

次のテーマは、UDPでNIOです。

最後は、今回書いたソースですよ。ちょっと長めのプログラムでしたね。
ReliableUdpClientServer.scala

import scala.annotation.tailrec
import scala.math.{abs, max, min}
import scala.util.{Failure, Success, Try}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, IOException}
import java.net.{DatagramPacket, DatagramSocket, InetAddress, InetSocketAddress, SocketAddress, SocketTimeoutException}
import java.nio.charset.StandardCharsets
import java.util.Date

import Reliabilities._

// このプログラムで、時間の単位はすべて「秒」

object Reliabilities {
  // タイムアウトの最小・最大値
  val MIN_RETRANSMIT_TIMEOUT: Int = 1
  val MAX_RETRANSMIT_TIMEOUT: Int = 64
  // ひとつのデータグラムの最大再送回数、3か4くらい
  val MAX_RETRANSMISSIONS: Int = 4

  // 通信先ポート
  val PORT: Int = 50000

  def log(msg: Any, msgs: Any*): Unit =
    println(s"[${new Date}] ${(msg :: msgs.toList).mkString(" ")}")

  // AutoCloseableをfor式で使えるようにするための、Implicit Class
  implicit class AutoCloseableWrapper[A <: AutoCloseable](val underlying: A) extends AnyVal {
    def foreach(fun: A => Unit): Unit =
      try {
        fun(underlying)
      } finally {
        underlying.close()
      }
  }
}

class RoundTripTimer {
  // 最も最近のラウンドトリップタイム(RTT)
  private var roundTripTime: Float = 0.0F
  // 平滑化したRTT
  private var smoothedTripTime: Float = 0.0F
  // 平滑化した標準偏差
  private var deviation: Float = 0.75F
  // 再送カウント:0, 1, 2
  private var retransmissions: Short = 0
  // 現在の再送タイムアウト
  var currentTimeout = minmax(calculateRetransmitTimeout)

  // 再送タイムアウトを返す
  private def calculateRetransmitTimeout: Int =
    (smoothedTripTime + 4.0 * deviation).toInt

  // 上限のある再送タイムアウトを返す
  private def minmax(rto: Float): Float =
    min(max(rto, MIN_RETRANSMIT_TIMEOUT), MAX_RETRANSMIT_TIMEOUT)

  // 新たなパケットを送信する度に、現在の再送カウントを初期化する
  def newPacket(): Unit =
    retransmissions = 0

  /**
   * 成功した受信の直後に呼ばれ、ラウンドトリップタイムを計算し、
   * 次に平滑化したラウンドトリップとその分散(偏差)を計算する
   **/
  def stoppedAt(ms: Long): Unit = {
    // このパケットのラウンドトリップを計算する
    roundTripTime = ms/1000

    // ラウンドリップタイムの推計値とその平均偏差を更新する
    val delta = roundTripTime - smoothedTripTime
    smoothedTripTime += (delta / 8.0).toFloat
    deviation += ((abs(delta) - deviation) / 4.0).toFloat

    // 現在のタイムアウトを再計算する
    currentTimeout= minmax(calculateRetransmitTimeout)
  }

  /**
   * タイムアウトが生じた後に呼ばれる。ギブアップすべき時間なら true を返却し、
   * 再送できるなら false を返却する
   **/
  def isTimeout(): Boolean = {
    currentTimeout *= 2  // 次の再送タイムアウト
    retransmissions = (retransmissions + 1).toShort
    retransmissions > MAX_RETRANSMISSIONS
  }
}

class ReliableDatagramSocket(localAddr: SocketAddress) extends DatagramSocket(localAddr) {
  private var roundTripTimer: RoundTripTimer = new RoundTripTimer
  private var reinit: Boolean = false
  var sendSequenceNo: Long = 0L  // 送信の順序番号
  var recvSequenceNo: Long = 0L  // 受信の順序番号

  init()

  def this(port: Int) = this(new InetSocketAddress(port))
  def this(port: Int, localAddr: InetAddress) = this(new InetSocketAddress(localAddr, port))
  def this() = this(null)

  // 初期化
  private def init(): Unit =
    roundTripTimer = new RoundTripTimer

  // コネクトした後、接続用の統計を(再)初期化する
  override def connect(dest: InetAddress, port: Int): Unit = {
    super.connect(dest, port)
    init()
  }

  // コネクトした後、接続用の統計を(再)初期化する
  override def connect(dest: SocketAddress): Unit = {
    super.connect(dest)
    init()
  }

  @throws(classOf[IOException])
  def sendReceive(sendPacket: DatagramPacket, recvPacket: DatagramPacket): Unit = synchronized {
    // タイムアウト後に再初期化する
    if (reinit) {
      init()
      reinit = false
    }

    roundTripTimer.newPacket()

    val start = System.currentTimeMillis
    val sequenceNumber = sendSequenceNo

    // 最後のタイムアウト、または予期しない例外が起きるまで繰り返し
    // リトライ中は、同じsequenceNumberを使用し続ける
    Iterator.continually {
      Try {
        sendSequenceNo = sequenceNumber
        send(sendPacket)  // 例外を投げても良い

        val timeout = (roundTripTimer.currentTimeout * 1000.0 + 0.5).toInt
        val soTimeoutStart = System.currentTimeMillis

        @tailrec
        def receiveRetries(): Long = {
          // ソケットのタイムアウト値を、すでに経過した時間で調整する
          val soTimeout = (timeout - (System.currentTimeMillis.toInt - soTimeoutStart)).toInt
          setSoTimeout(soTimeout)
          receive(recvPacket)
          recvSequenceNo match {
            case `sequenceNumber` => recvSequenceNo // シーケンスが一致していれば、ループをストップ
            case _ => receiveRetries()
          }
        }

        receiveRetries()
      }
    }.takeWhile {
      case Success(_) => false  // シーケンスが一致していれば、ループをストップ
      case Failure(e: SocketTimeoutException) =>
        // タイムアウトで、リトライするかどうか
        if (roundTripTimer.isTimeout()) {
          reinit = true
          throw e
        } else {
          // リトライする
          true
        }
      case Failure(e) => throw e
    }.foreach { retry => } // シーケンスの不一致、またはタイムアウトのためリトライ

    // 正しいリプライを得た
    // タイマーを停止し、新たなRTTの値を計算する
    val ms = System.currentTimeMillis - start
    roundTripTimer.stoppedAt(ms)
  }

  // 順序番号を処理する
  @throws(classOf[IOException])
  override def receive(packet: DatagramPacket): Unit = {
    super.receive(packet)

    // 順序番号を読み、それをパケットから削除する
    val bais = new ByteArrayInputStream(packet.getData,
                                        packet.getOffset,
                                        packet.getLength)

    val dis = new DataInputStream(bais)
    recvSequenceNo = dis.readLong()
    val buffer = Array.ofDim[Byte](dis.available)
    dis.read(buffer)
    packet.setData(buffer, 0, buffer.size)
  }

  // 順序番号を処理する
  @throws(classOf[IOException])
  override def send(packet: DatagramPacket): Unit = {
    val baos = new ByteArrayOutputStream
    val dos = new DataOutputStream(baos)

    // 順序番号を書き出し、次にユーザデータを書き出す
    dos.writeLong(sendSequenceNo)
    sendSequenceNo += 1
    dos.write(packet.getData, packet.getOffset, packet.getLength)
    dos.flush()

    // この新しいデータで新たなパケットをコンストラクトし、送信する
    val data = baos.toByteArray
    val newPacket = new DatagramPacket(data, data.size, packet.getSocketAddress)
    super.send(newPacket)
  }
}

object ReliableEchoServer {
  def main(args: Array[String]): Unit = {
    val buffer = Array.ofDim[Byte](1024)
    val recvPacket = new DatagramPacket(buffer, buffer.size)

    for (socket <- new ReliableDatagramSocket(PORT)) {
      log("Reliable Single Thread Scala UDP Server", socket.getLocalSocketAddress, "Startup.")

      Iterator.from(1).foreach { i =>
        // 受信パケットのバッファを、最大にリセットする
        recvPacket.setData(buffer, 0, buffer.size)
        socket.receive(recvPacket)

        // リプライは、リクエストと同じ順序番号とする
        val seqNo = socket.recvSequenceNo
        socket.sendSequenceNo = seqNo

        // レスポンスとして、リクエストをエコーバックする
        socket.send(recvPacket)

        log(s"Receive & Send SeqNo[${socket.recvSequenceNo}]")
      }
    }
  }
}

object ReliableEchoClient {
  def main(args: Array[String]): Unit = {
    // 送信用バッファ
    val sendBuffer = Array.ofDim[Byte](1024)
    // 受信用バッファ
    val recvBuffer = Array.ofDim[Byte](1024)

    val sendPacket = new DatagramPacket(sendBuffer, sendBuffer.size)
    val recvPacket = new DatagramPacket(recvBuffer, recvBuffer.size)
    val address = new InetSocketAddress(PORT)
    sendPacket.setSocketAddress(address)
    recvPacket.setSocketAddress(address)

    for (socket <- new ReliableDatagramSocket) {
      log(s"Reliable Scala UDP Client", address, "Startup.")

      Iterator
        .continually(readLine())
        .takeWhile(word => word != null && word != "exit")
        .foreach { word =>
          // 送信用、受信用パケットをそれぞれ初期化
          sendPacket.setData(sendBuffer, 0, sendBuffer.size)
          recvPacket.setData(recvBuffer, 0, recvBuffer.size)

          val wordBinary = word.getBytes(StandardCharsets.UTF_8)
          System.arraycopy(wordBinary,
                           0,
                           sendPacket.getData,
                           0,
                           wordBinary.size)

          sendPacket.setData(sendPacket.getData, 0, wordBinary.size)

          // 送信 & 受信
          socket.sendReceive(sendPacket, recvPacket)

          val seqNo = socket.recvSequenceNo

          log(s"Receive Sequence No[$seqNo]")
          log(s"Received =>", new String(recvPacket.getData,
                                         recvPacket.getOffset,
                                         recvPacket.getLength,
                                         StandardCharsets.UTF_8))
        }
    }
  }
}

スパム対策のためのダミーです。もし見えても何も入力しないでください
ゲスト


画像認証

idトラックバック - http://d.hatena.ne.jp/Kazuhira/20130921