giftee Tech Blog

ギフティの開発を支えるメンバーの技術やデザイン、プロダクトマネジメントの情報を発信しています。

gRPC server with Go

こんにちは!ギフティでエンジニアをやっている中屋(@nakaryo79)です。
今まで長らく iphone7 を愛用してきたのですが、ちょっとした気の迷いと勢いで iphone13 をポチり、今は到着を心待ちにしています(2週間は長いって!!)。
そしてバージョンアップといえば、Go 1.17 リリースされましたね!今日は Go 1.17 の話ではないですが、 Go に関連したトピックで筆を取りました。

はじめに

弊社のプロダクトの多くは Ruby で書かれていますが、一部開発中のプロダクトに関しては backend の実装言語に Go を採用しているものもあります。
また、責務分離の観点から単一モノリスをマイクロサービスに分割する動きもあり、我々のチームで開発しているシステムは、サービス間通信に従来の REST の仕組みではなく gRPC を採用することを検討しています。
社内では通信 Framework に GraphQL を採用しているところも多いんですが、そこまでリッチなクエリを組み立てるユースケースがないかつ、速度面や Go との親和性及び周辺エコシステムの豊富さ、その他諸々の理由から gRPC を選択しました。
で、実装するにあたって、gRPC のサーバーってどう動いてるの? net http サーバーと一緒で内部的にはリクエストごとに goroutine 作ってるんだよな?worker の設定とかしなくていいんだっけ?となったので実際にコードを読みながら調べてみました。

対象読者

  • Go が読める人、興味がある人
    • gRPC と言っても、Go のサーバー側実装の話しかしません

TL;DR

え〜、worker の台数設定に関して先に結論を書いておきますと、2021年10月現在では、gRPC worker 台数の設定はしなくてもいいようです。
設定するためのオプションは存在しますが、そもそもライブラリ側の実装が未だ beta (αかな?)の状態らしいです。

本題

Go で gRPC のサーバーを立てようとすると大抵以下のようなコードを書くと思います(grpc-goサンプルをそのまま持ってきただけです)。

import (
    "google.golang.org/grpc"
    // 省略
)

func main() {
    lis, err := net.Listen("tcp", port)
    if err != nil {
        log.Fatalf("failed to listen: %v", err)
    }
    s := grpc.NewServer()
    pb.RegisterGreeterServer(s, &server{})
    log.Printf("server listening at %v", lis.Addr())
    if err := s.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

なお、執筆時点(2021年10月)で latest stable の grpc-go v1.41 のコードを前提としています。 上記の s.Serve (= grpc.Server)がどのような仕組みでリクエストを処理するかというのを追っていきます。

Server.Serve の実装

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L733-L825

// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them.
// Serve returns when lis.Accept fails with fatal errors.  lis will be closed when
// this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
func (s *Server) Serve(lis net.Listener) error {
    s.mu.Lock()
    s.printf("serving")
    s.serve = true
    if s.lis == nil {
        // Serve called after Stop or GracefulStop.
        s.mu.Unlock()
        lis.Close()
        return ErrServerStopped
    }

    s.serveWG.Add(1)
    defer func() {
        s.serveWG.Done()
        if s.quit.HasFired() {
            // Stop or GracefulStop called; block until done and return nil.
            <-s.done.Done()
        }
    }()

    ls := &listenSocket{Listener: lis}
    s.lis[ls] = true

    if channelz.IsOn() {
        ls.channelzID = channelz.RegisterListenSocket(ls, s.channelzID, lis.Addr().String())
    }
    s.mu.Unlock()

    defer func() {
        s.mu.Lock()
        if s.lis != nil && s.lis[ls] {
            ls.Close()
            delete(s.lis, ls)
        }
        s.mu.Unlock()
    }()

    var tempDelay time.Duration // how long to sleep on accept failure

    for {
        rawConn, err := lis.Accept()
        if err != nil {
            if ne, ok := err.(interface {
                Temporary() bool
            }); ok && ne.Temporary() {
                if tempDelay == 0 {
                    tempDelay = 5 * time.Millisecond
                } else {
                    tempDelay *= 2
                }
                if max := 1 * time.Second; tempDelay > max {
                    tempDelay = max
                }
                s.mu.Lock()
                s.printf("Accept error: %v; retrying in %v", err, tempDelay)
                s.mu.Unlock()
                timer := time.NewTimer(tempDelay)
                select {
                case <-timer.C:
                case <-s.quit.Done():
                    timer.Stop()
                    return nil
                }
                continue
            }
            s.mu.Lock()
            s.printf("done serving; Accept = %v", err)
            s.mu.Unlock()

            if s.quit.HasFired() {
                return nil
            }
            return err
        }
        tempDelay = 0
        // Start a new goroutine to deal with rawConn so we don't stall this Accept
        // loop goroutine.
        //
        // Make sure we account for the goroutine so GracefulStop doesn't nil out
        // s.conns before this conn can be added.
        s.serveWG.Add(1)
        go func() {
            s.handleRawConn(lis.Addr().String(), rawConn)
            s.serveWG.Done()
        }()
    }
}

コードコメントが丁寧に書かれていますね。 色々やってるように見えますが、受け取った net listener からサーバーを立てるのに必要な情報を取得したりなんなりして、無限ループに入ります。 ループの中では listner が受け付けた tcp コネクションを検知して、このコネクションの確立を試みます。この際、ループがブロッキングされないようかつ Graceful stop できるよう別 goroutine でその接続を処理しようとします。
その goroutine は handleRawConn メソッドを実行するよう指示され、処理が終わったら、waitGroup を done にして goroutine を破棄します。ここでの WaitGroup を Wait するのは Server.StopServer.GracefulStop であり、コネクションを確立している最中にぶった切られないように監視している、というわけです。各コネクションの処理が完了したか、はまた別のところで監視が行われます。

上記を延々ループで待ち受ける、というのが Serve メソッドの役割です。
これ自体はリスナーを監視して、コネクションを確立する準備をしているだけなので、実態は handleRawConn 以降の呼び出しを見ていくことになります。

Server.handleRawConn の実装

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L827-L850

handleRawConn のコメントには、受け入れたばかりのコネクションをハンドルする goroutine を fork する、と書いてあります。

// handleRawConn forks a goroutine to handle a just-accepted connection that
// has not had any I/O performed on it yet.
func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
    if s.quit.HasFired() {
        rawConn.Close()
        return
    }
    rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))

    // Finish handshaking (HTTP2)
    st := s.newHTTP2Transport(rawConn)
    rawConn.SetDeadline(time.Time{})
    if st == nil {
        return
    }

    if !s.addConn(lisAddr, st) {
        return
    }
    go func() {
        s.serveStreams(st)
        s.removeConn(lisAddr, st)
    }()
}

gRPC は通信プロトコルとして HTTP2 を採用しています。 newHTTP2Transport では、TLS ハンドシェイクなどの処理をして、 ServerTransport インターフェースの実装として HTTP2 の実装を返します。 最後にさらに別 goroutine を作って Server.serveStreams というメソッドに、今作った ServerTransport を渡します。ここまでくるとコネクションが確立しているので、リターンして先程の WaitGroup を DONE にできるということです。

Server.serveStreams の実装

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L902-L934

ここでの Stream というのは grpc の Unary/Stream のことではなく、逐次流れてくる io steram のことを指しています。
引数の ServerTransportHandleStreams にファンクションを渡しています。
ここでやっと、これを調べるきっかけになった server option の numServerWorkers の設定が登場してきます。

func (s *Server) serveStreams(st transport.ServerTransport) {
    defer st.Close()
    var wg sync.WaitGroup

    var roundRobinCounter uint32
    st.HandleStreams(func(stream *transport.Stream) {
        wg.Add(1)
        if s.opts.numServerWorkers > 0 {
            data := &serverWorkerData{st: st, wg: &wg, stream: stream}
            select {
            case s.serverWorkerChannels[atomic.AddUint32(&roundRobinCounter, 1)%s.opts.numServerWorkers] <- data:
            default:
                // If all stream workers are busy, fallback to the default code path.
                go func() {
                    s.handleStream(st, stream, s.traceInfo(st, stream))
                    wg.Done()
                }()
            }
        } else {
            go func() {
                defer wg.Done()
                s.handleStream(st, stream, s.traceInfo(st, stream))
            }()
        }
    }, func(ctx context.Context, method string) context.Context {
        if !EnableTracing {
            return ctx
        }
        tr := trace.New("grpc.Recv."+methodFamily(method), method)
        return trace.NewContext(ctx, tr)
    })
    wg.Wait()
}

ワーカー数が設定がない場合は愚直に goroutine を作って Server.handleStream を呼びます。
ワーカー数が設定されている場合は、serverWorkerData という構造体インスタンスを作成し、select でゴニョゴニョやっています。 select case で待ち構えている s.serverWorkerChannels は channel の slice になっています。 こいつは NewSever のタイミングで初期化され、ワーカー数の設定があれば、その長さ分の channel の slice がセットされます(以下コード参照)。

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L543-L551

// initServerWorkers creates worker goroutines and channels to process incoming
// connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
    s.serverWorkerChannels = make([]chan *serverWorkerData, s.opts.numServerWorkers)
    for i := uint32(0); i < s.opts.numServerWorkers; i++ {
        s.serverWorkerChannels[i] = make(chan *serverWorkerData)
        go s.serverWorker(s.serverWorkerChannels[i])
    }
}

そのそれぞれの channel では worker goroutine が実行されていて、受信したデータを処理するようになっています。つまりこのコネクションで入ってきたデータを channel に流し込むとすでにプールされているワーカーがいい感じに処理してくれるわけです。どの channel に流すかは roundRobinCounter という変数から、どうやらラウンドロビン方式になっていそうです。

詳細まで追ってないですが、以下を読む限りだと、
https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L522-L541

// serverWorkers blocks on a *transport.Stream channel forever and waits for
// data to be fed by serveStreams. This allows different requests to be
// processed by the same goroutine, removing the need for expensive stack
// re-allocations (see the runtime.morestack problem [1]).

ワーカー設定をするとワーカープールが有効になって毎回 goroutine を新規に作らず、既にあるものを使い回すので、 goroutine 作って破棄するだけのオーバーヘッドがなくなり効率がいいね、という話らしいです。どうやらワーカー数分だけチャネル作っといて、上手い感じにやるっぽい。
goroutine 作るのにそんなにコストかかるんだっけというのもあるので、そんなに変わるのか?というお気持ちがあったりなかったりしますが、この辺はロードテストをやってみないとわからないですね。やる気が出たら調べてみます。

ただし、全てのワーカーがビージーな場合は、ワーカー設定がない時と同じ挙動を取るっぽいので、どうやら生成される goroutine の上限をコントロールできるわけではないようです。
また、ワーカー数を設定するオプションは Experimental ステータスとなっていて、まだ stable ではないらしいので、利用する場合は破壊的変更が入る可能性があることに注意が必要です。
https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L496-L513

ServerTransport.HandleStreams の実装

ServerTransport の実装は HTTP2 でしたね。ですので、実装の詳細は http2Server の実装を見る必要があります。

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/internal/transport/http2_server.go#L551-L614

// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
    defer close(t.readerDone)
    for {
        t.controlBuf.throttle()
        frame, err := t.framer.fr.ReadFrame()
        atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
        if err != nil {
            if se, ok := err.(http2.StreamError); ok {
                if logger.V(logLevel) {
                    logger.Warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
                }
                t.mu.Lock()
                s := t.activeStreams[se.StreamID]
                t.mu.Unlock()
                if s != nil {
                    t.closeStream(s, true, se.Code, false)
                } else {
                    t.controlBuf.put(&cleanupStream{
                        streamID: se.StreamID,
                        rst:      true,
                        rstCode:  se.Code,
                        onWrite:  func() {},
                    })
                }
                continue
            }
            if err == io.EOF || err == io.ErrUnexpectedEOF {
                t.Close()
                return
            }
            if logger.V(logLevel) {
                logger.Warningf("transport: http2Server.HandleStreams failed to read frame: %v", err)
            }
            t.Close()
            return
        }
        switch frame := frame.(type) {
        case *http2.MetaHeadersFrame:
            if t.operateHeaders(frame, handle, traceCtx) {
                t.Close()
                break
            }
        case *http2.DataFrame:
            t.handleData(frame)
        case *http2.RSTStreamFrame:
            t.handleRSTStream(frame)
        case *http2.SettingsFrame:
            t.handleSettings(frame)
        case *http2.PingFrame:
            t.handlePing(frame)
        case *http2.WindowUpdateFrame:
            t.handleWindowUpdate(frame)
        case *http2.GoAwayFrame:
            // TODO: Handle GoAway from the client appropriately.
        default:
            if logger.V(logLevel) {
                logger.Errorf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
            }
        }
    }
}

無限ループしながら、HTTP2 のフレームを読み出して、そのタイプによって処理を分岐させています。
HTTP2 には HTTP1.X と違い、ストリームやバイナリフレーミングという概念があり、一つのコネクションで同時に並行して複数のリクエスト/レスポンスを処理することができるようになっています。ここは tcp コネクションに対する処理ですが、goroutine を駆使して、ストリームに流れてくるリクエストデータを非同期でうまいこと捌けるような実装になっています。
HTTP2 の話をし出すとそれだけで記事が一本書けてしまうので、ここでは割愛します(というか自分もそんなに詳しくない...)。

さて、この HandleStreams の第一引数に Server.handleStream を実行するファンクションを渡していたことを思い出してください。具体的なリクエストヘッダの解析はそこで行われます。

Server.handleStream の実装

https://github.com/grpc/grpc-go/blob/a671967dfbaab779d37fd7e597d9248f13806087/server.go#L1578-L1641

リクエストの詳細(メソッドやらリクエストの種類(Unary or Stream)やら)を解析します。
これより先(Server.processUnaryRPC or Server.processStreamingRPC)は具体的なペイロードの解析をして、サーバーに実装されている rpc メソッドを呼ぶだけです。

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
    sm := stream.Method()
    if sm != "" && sm[0] == '/' {
        sm = sm[1:]
    }
    pos := strings.LastIndex(sm, "/")
    if pos == -1 {
        if trInfo != nil {
            trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
            trInfo.tr.SetError()
        }
        errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
        if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
            if trInfo != nil {
                trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
                trInfo.tr.SetError()
            }
            channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
        }
        if trInfo != nil {
            trInfo.tr.Finish()
        }
        return
    }
    service := sm[:pos]
    method := sm[pos+1:]

    srv, knownService := s.services[service]
    if knownService {
        if md, ok := srv.methods[method]; ok {
            s.processUnaryRPC(t, stream, srv, md, trInfo)
            return
        }
        if sd, ok := srv.streams[method]; ok {
            s.processStreamingRPC(t, stream, srv, sd, trInfo)
            return
        }
    }
    // Unknown service, or known server unknown method.
    if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
        s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
        return
    }
    var errDesc string
    if !knownService {
        errDesc = fmt.Sprintf("unknown service %v", service)
    } else {
        errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
    }
    if trInfo != nil {
        trInfo.tr.LazyPrintf("%s", errDesc)
        trInfo.tr.SetError()
    }
    if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
        if trInfo != nil {
            trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
            trInfo.tr.SetError()
        }
        channelz.Warningf(logger, s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
    }
    if trInfo != nil {
        trInfo.tr.Finish()
    }
}

コードを抜粋するとこの辺です。s.services[service] はリクエストされてきたメソッドがこの gRPC サーバーに登録されているかを確認しています。メソッドがあればそのメソッドを呼び出します。

srv, knownService := s.services[service]
if knownService {
    if md, ok := srv.methods[method]; ok {
        s.processUnaryRPC(t, stream, srv, md, trInfo)
        return
    }
    if sd, ok := srv.streams[method]; ok {
        s.processStreamingRPC(t, stream, srv, sd, trInfo)
        return
    }
}

終わりに

サラッとですが、gRPC サーバーの仕組みを追ってみました。
(Go はコードが追いやすいかつメタ要素も薄いので oss のコードでも読みやすくていいですね...)
ライブラリやフレームワークは中身の詳細を知らなくても、それ通り書けば動作させることができるのがメリットではありますが、時には仕組みを知ることも大切ですね!(今回は興味9割でした笑)