我希望暴露一个接口,供其它程序调用神经网络。本打算 flask 弄个 restful 接口,不过想了下不如尝试一下 protobuf.

Protobuf 的基本原理

Protocol Buffers (简称 Protobuf,PB)是一种序列化规范。和常见的 JSON 一样,都能编码对象。但也有其优势,所以广泛应用于微服务等技术中。

JSON:可读性强,兼容性强,序列化、反序列化性能差。传输占用大(但是可以 GZIP 压缩)

Thrift:可读性差,解析性能强。Facebook 提供,和框架耦合性强。

Protobuf:可读性差,解析性能强(据说其实一言难尽),各种框架可用,传输层无关。

Protobuf 采用了 Varint 算法进行压缩。若数据类型为 sint32 或 sint64 ,用 ZigZag 算法进行数据压缩。

Protobuf 采用 (field_number << 3) | wire_type 的方式记录数据类型和顺序。field_number 相当于顺序,wire_type 相当于数据类型。

深入微服务:3. Protobuf 为啥比 JSON、XML 牛? - 潇洒哥和黑大帅 (printlove.cn)

例子:使用 gRPC 前后端分离的手写识别程序

完整代码位于 pluveto/digit-recog (github.com)

我们的神经网络能够将向量形式的图片识别出为一个整数。

  • 输入:一个 784 维的向量。

  • 输出:一个整数。

消息定义如下:

syntax = "proto3";

package recog;

service DigitRecog {
    rpc Recog (RecogRequest) returns (RecogResponse) {}
}

message RecogRequest {
    repeated int32 px = 1 [packed=true];
}

message RecogResponse {
    int32 num = 1;
}

message Error {
  string detail = 1;
};

Python 搭建 gRPC 服务

安装依赖

pip install grpcio
pip install grpcio-tools

代码生成:

python -m grpc_tools.protoc -I./ --python_out=. --grpc_python_out=. ./recog.proto 

app.py 代码如下:

from concurrent import futures
import time

import grpc

import recog_pb2
import recog_pb2_grpc

from tensorflow import keras
import numpy as np

model = None

def load_model():
    global model
    model = keras.models.load_model('demo_model.h5')

def gen_submission():
    data = np.genfromtxt('test.csv',delimiter=',',skip_header=1)
    with open('submission.csv', mode="w", encoding="utf-8") as fout:
        fout.write("%s,%s\n" % ("ImageId", "Label"))
        for i in range(len(data)):
            fout.write("%s,%s\n" % (i + 1, recog(data[i])))


def recog(img):
    """
    input: np.array
    """
    img = img.reshape((1, 28, 28, 1)) / 255.0
    predict = model.predict(img)[0]
    maxPredict = -1
    maxPredictIndex = -1
    for i in range(9):
        if predict[i] > maxPredict:
            maxPredict = predict[i]
            maxPredictIndex = i
    return maxPredictIndex


print("Loading model...")
load_model()
print("Model loaded")


class DigitRecog(recog_pb2_grpc.DigitRecogServicer):
    def Recog(self, request, context):
        print("req: ")
        print(request.px)
        resp = recog(np.array(request.px))        
        print("resp: ")
        print(resp)
        return recog_pb2.RecogResponse(num=resp)


def serve():
    # gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    recog_pb2_grpc.add_DigitRecogServicer_to_server(DigitRecog(), server)
    server.add_insecure_port('localhost:7042')
    server.start()
    print("Listening...")
    server.wait_for_termination()

print("Loading server...")
serve()
# print("Generating submission")
# gen_submission()
# print("Done.")

C# 作为 gRPC 客户端

安装依赖

Install-Package Grpc.Net.Client
Install-Package Google.Protobuf
Install-Package Grpc.Tools

新建目录 Protos,复制 proto 文件进来。添加一行:

option csharp_namespace = "GrpcRecogClient";

编辑 LBNN.GUI.csproj 项目文件,添加:

  <ItemGroup>
    <Protobuf Include="Protos\recog.proto" GrpcServices="Client" />
  </ItemGroup>
PS C:\Repo\digit-recognizer\LBNN> .\packages\Grpc.Tools.2.42.0\tools\windows_x64\protoc.exe -I .\LBNN.GUI\Protos\ --csharp_out=.\LBNN.GUI\Protos\ .\LBNN.GUI\Protos\recog.proto
protoc.exe -I <INCLUDE> --csharp_out=<CS_CLASS_OUT_DIR> --grpc_out=<CS_CLASS_OUT_DIR> --plugin=protoc-gen-grpc=grpc_csharp_plugin.exe <PROTO_FILE> 
.\packages\Grpc.Tools.2.42.0\tools\windows_x64\protoc.exe -I .\LBNN.GUI\Protos\ --csharp_out=.\LBNN.GUI\Protos\ --grpc_out=.\LBNN.GUI\Protos\ --plugin=protoc-gen-grpc=.\packages\Grpc.Tools.2.42.0\tools\windows_x64\grpc_csharp_plugin.exe .\LBNN.GUI\Protos\recog.proto

也可以通过设置 Build Action 编译,但是我这么设置之后出现 编译不报错,但是编译失败 的诡异情况。

image_up_1639756047be2ba3e5.jpg

..\packages\Grpc.Tools.2.27.0\tools\windows_x64\protoc.exe -I.\Proto\ --csharp_out ./ --grpc_out ./ --plugin=protoc-gen-grpc=grpc_csharp_plugin.exe .\Proto\routine.proto

请求的关键代码为:

        private async void RecogButton_Click(object sender, RoutedEventArgs e)
        {
            var channel = GrpcChannel.ForAddress("http://localhost:7042");
            var client = new Proto.DigitRecog.DigitRecogClient(channel);
            AppContext.SetSwitch("System.Net.Http.SocketsHttpHandler.Http2UnencryptedSupport", true);
            var reply = await client.RecogAsync(
                new Proto.RecogRequest
                {
                    Px = { track }
                }
            );
            MessageBox.Show(reply.Num.ToString());
        }

例子:长连接的 Go gRPC 服务

安装依赖

# gRPC
go get -u google.golang.org/grpc
# ProtoBuf
go get -u github.com/golang/protobuf/protoc-gen-go

代码生成

protoc -I gmem/ gmem/UserService.proto --go_out=plugins=grpc:gmem

参考

绿色记忆:gRPC 学习笔记 (gmem.cc)