Building a TensorFlow Serving Java Client

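The code needed for a TensorFlow Serving Java client lives in two projects: https://github.com/tensorflow/tensorflow and https://github.com/tensorflow/serving. The interfaces are defined as protocol buffer (.proto) files, so we first need to install the protoc compiler and use it to generate Java code from them.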

Installing protoc 3 on Mac

protoc 3 is available prebuilt. Download protoc-3.5.1-osx-x86_64.zip from the official releases page https://github.com/google/protobuf/releases, then copy the protoc binary into /usr/local/bin.

cd /tmp
mv ~/Downloads/protoc-3.5.1-osx-x86_64.zip .   # adjust to wherever you downloaded the zip
unzip protoc-3.5.1-osx-x86_64.zip
cd bin
cp protoc /usr/local/bin/

Verify the installation:

$ protoc --version
libprotoc 3.5.1

Installing protoc 2 on Mac

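protoc 2 is rarely needed anymore; protoc 3 also compiles .proto files that declare syntax = "proto2". It is covered here in case an older project specifically needs the 2.x compiler. There is no official prebuilt protoc 2 for Mac, so it has to be built from source: download protobuf-2.6.1 from the official releases page https://github.com/google/protobuf/releases.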

Build and install:

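cd /opt
mv ~/Downloads/protobuf-2.6.1.tar.gz .   # adjust to wherever you downloaded the tarball
tar -xkzvf protobuf-2.6.1.tar.gz
cd protobuf-2.6.1
# The default install prefix is /usr/local, so this overwrites a protoc 3 binary in /usr/local/bin;
# run ./configure --prefix=<another dir> instead if you need to keep both versions
./configure && make && make install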

Verify the installation:

$ protoc --version
libprotoc 2.6.1
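Both versions install a binary named protoc, so it can help to confirm which one is on the PATH before running Maven. The following is a minimal sketch in plain Java with no protobuf dependency. It assumes protoc is on the PATH and prints a line like `libprotoc 3.5.1`. The class and method names are hypothetical. It compares against 3.5.1 only because that is the protobuf-java version used in the Maven configuration below.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper: reports which protoc is on the PATH and whether it matches
// the protobuf-java version used by the Maven project (3.5.1 in this article).
final class ProtocVersionCheck {
    static final String EXPECTED = "3.5.1";
    // Matches output such as "libprotoc 3.5.1" or "libprotoc 2.6.1".
    private static final Pattern VERSION = Pattern.compile("libprotoc\\s+(\\d+\\.\\d+\\.\\d+)");

    private ProtocVersionCheck() {}

    /** Extracts "X.Y.Z" from `protoc --version` output, or returns null if it does not match. */
    static String parseVersion(String output) {
        Matcher m = VERSION.matcher(output);
        return m.find() ? m.group(1) : null;
    }

    /** Runs `protoc --version` and returns its first line of output (empty if none). */
    static String runProtocVersion() throws IOException, InterruptedException {
        Process p = new ProcessBuilder("protoc", "--version").redirectErrorStream(true).start();
        String line;
        try (BufferedReader r = new BufferedReader(
                new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
            line = r.readLine();
        }
        p.waitFor();
        return line == null ? "" : line;
    }

    public static void main(String[] args) throws Exception {
        String version = parseVersion(runProtocVersion());
        if (EXPECTED.equals(version)) {
            System.out.println("protoc " + version + " matches protobuf-java " + EXPECTED);
        } else {
            System.out.println("protoc on PATH is " + version + ", expected " + EXPECTED);
        }
    }
}
```

The parsing is kept separate from the process call so that it can be checked without protoc installed.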

Generating Java code with the Maven plugin

Create a new Maven project, for example tensorflow-serving-api, and add the following Maven configuration.

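<build>
  <extensions>
    <!-- Provides ${os.detected.classifier}, used to pick the right protoc-gen-grpc-java binary -->
    <extension>
      <groupId>kr.motd.maven</groupId>
      <artifactId>os-maven-plugin</artifactId>
      <version>1.5.0.Final</version>
    </extension>
  </extensions>
  <plugins>
    <plugin>
      <groupId>org.xolstice.maven.plugins</groupId>
      <artifactId>protobuf-maven-plugin</artifactId>
      <version>0.5.1</version>
      <configuration>
        <protocExecutable>/usr/local/bin/protoc</protocExecutable>
        <!-- The gRPC plugin generates the service stub class PredictionServiceGrpc -->
        <pluginId>grpc-java</pluginId>
        <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.12.0:exe:${os.detected.classifier}</pluginArtifact>
      </configuration>
      <executions>
        <execution>
          <goals>
            <goal>compile</goal>
            <goal>compile-custom</goal>
            <goal>test-compile</goal>
          </goals>
        </execution>
      </executions>
    </plugin>
  </plugins>
</build>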

<dependencies>
  <dependency>
    <groupId>com.google.protobuf</groupId>
    <artifactId>protobuf-java</artifactId>
    <version>3.5.1</version>
  </dependency>
  <dependency>
    <groupId>io.grpc</groupId>
    <artifactId>grpc-stub</artifactId>
    <version>1.12.0</version>
  </dependency>
  <dependency>
    <groupId>io.grpc</groupId>
    <artifactId>grpc-protobuf</artifactId>
    <version>1.12.0</version>
  </dependency>
  <dependency>
    <groupId>io.grpc</groupId>
    <artifactId>grpc-netty</artifactId>
    <version>1.12.0</version>
  </dependency>
</dependencies>

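Then run mvn protobuf:compile to generate Java code from the *.proto files in the project. The first run reports that the src/main/proto directory is missing. Create it, then copy the following files into it from the two projects mentioned above.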

src/main/proto
├── tensorflow
│   └── core
│       ├── example
│       │   ├── example.proto
│       │   └── feature.proto
│       ├── framework
│       │   ├── attr_value.proto
│       │   ├── function.proto
│       │   ├── graph.proto
│       │   ├── node_def.proto
│       │   ├── op_def.proto
│       │   ├── resource_handle.proto
│       │   ├── tensor.proto
│       │   ├── tensor_shape.proto
│       │   ├── types.proto
│       │   └── versions.proto
│       └── protobuf
│           ├── meta_graph.proto
│           └── saver.proto
└── tensorflow_serving
    └── apis
        ├── classification.proto
        ├── get_model_metadata.proto
        ├── inference.proto
        ├── input.proto
        ├── model.proto
        ├── predict.proto
        ├── prediction_service.proto
        └── regression.proto

Now run mvn clean package to generate the Java sources and build a usable tensorflow-serving-api jar. On its own, mvn clean protobuf:compile only generates the sources.
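The client code below refers to Predict.PredictRequest and Model.ModelSpec rather than PredictRequest and ModelSpec. This is because predict.proto and model.proto set neither java_outer_classname nor java_multiple_files, so protoc wraps their messages in an outer class named after the file. In contrast, files such as tensor.proto set java_multiple_files = true, which is why TensorProto is used directly.

The following is a simplified sketch of the default naming rule. It takes the file name, drops ".proto", removes underscores, and capitalizes the next letter. It ignores the extra rule where protoc appends "OuterClass" when a top-level type has the same name. The class OuterClassNames is a hypothetical illustration, not part of protobuf.

```java
// Simplified sketch of protoc's default Java outer class name (no java_outer_classname set).
final class OuterClassNames {
    private OuterClassNames() {}

    static String defaultOuterClassName(String protoPath) {
        // Keep only the file name, e.g. "tensorflow_serving/apis/predict.proto" -> "predict.proto"
        String name = protoPath.substring(protoPath.lastIndexOf('/') + 1);
        if (name.endsWith(".proto")) {
            name = name.substring(0, name.length() - ".proto".length());
        }
        StringBuilder sb = new StringBuilder();
        boolean capitalizeNext = true; // the first letter is capitalized
        for (char c : name.toCharArray()) {
            if (c >= 'a' && c <= 'z') {
                sb.append(capitalizeNext ? Character.toUpperCase(c) : c);
                capitalizeNext = false;
            } else if (c >= 'A' && c <= 'Z') {
                sb.append(c);
                capitalizeNext = false;
            } else if (c >= '0' && c <= '9') {
                sb.append(c);
                capitalizeNext = true;
            } else {
                // '_' and other separators are dropped and capitalize the next letter
                capitalizeNext = true;
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        for (String p : new String[] {"tensorflow_serving/apis/predict.proto",
                "tensorflow_serving/apis/model.proto",
                "tensorflow_serving/apis/get_model_metadata.proto"}) {
            System.out.println(p + " -> " + defaultOuterClassName(p));
        }
    }
}
```

The gRPC stub class PredictionServiceGrpc follows a different rule: it is named after the service PredictionService, not the file.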

TensorFlow Serving Java Client

Using Tensor inputs

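import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.nio.charset.StandardCharsets;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

public class TensorClient {
    public static void main(String[] args) {
        // Model name and signature to call
        Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
        Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
        modelSpec.setName("mnist");
        modelSpec.setSignatureName("serving_default");
        request.setModelSpec(modelSpec);

        // 1-D shape [1]: a single example
        TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(1).build();
        TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();

        TensorProto.Builder tensor = TensorProto.newBuilder();
        tensor.setTensorShape(shape);
        tensor.setDtype(DataType.DT_INT32);
        tensor.addIntVal(10);
        request.putInputs("input_label", tensor.build());

        // DT_STRING values are raw bytes, so the text is encoded explicitly as UTF-8
        tensor.clear();
        tensor.setTensorShape(shape);
        tensor.setDtype(DataType.DT_STRING);
        tensor.addStringVal(ByteString.copyFrom("新品,冰淇淋,不错,哈根达,斯,冰淇淋,不错,麻薯,口味,单一,口味,时代,吃,抹茶,口味,终于,选择,甜,吃,舒服,一点,吃,完,口,渴", StandardCharsets.UTF_8));
        request.putInputs("input_feature", tensor.build());

        // No shape set: keep_prob is sent as a scalar
        tensor.clear();
        tensor.setDtype(DataType.DT_FLOAT);
        tensor.addFloatVal(1.0F);
        request.putInputs("keep_prob", tensor.build());

        // usePlaintext() replaces the deprecated usePlaintext(true)
        ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 9000).usePlaintext().build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            Predict.PredictResponse response = stub.predict(request.build());
            System.out.println(response);
        } finally {
            channel.shutdown();
        }
    }
}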
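The input_feature value above is a single comma-separated string of pre-segmented tokens. Because DT_STRING tensors carry raw bytes, the text must be encoded explicitly; ByteString.copyFrom(text, StandardCharsets.UTF_8) does that in the code above.

The following is a minimal sketch of preparing the same bytes from a token list. It assumes that the model splits the feature on commas, as the example string suggests. FeatureText and encodeTokens are hypothetical names.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: joins pre-segmented tokens with commas (assumed to be the model's
// separator) and encodes them as UTF-8, the bytes that end up in the DT_STRING tensor.
final class FeatureText {
    private FeatureText() {}

    static byte[] encodeTokens(List<String> tokens) {
        for (String token : tokens) {
            // A comma inside a token would be indistinguishable from the separator
            if (token.indexOf(',') >= 0) {
                throw new IllegalArgumentException("token contains a comma: " + token);
            }
        }
        return String.join(",", tokens).getBytes(StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        byte[] bytes = encodeTokens(Arrays.asList("新品", "冰淇淋", "不错"));
        // Each of these Chinese characters takes 3 bytes in UTF-8: 7 * 3 + 2 commas
        System.out.println(bytes.length + " bytes");
    }
}
```

The resulting array can be wrapped with ByteString.copyFrom(bytes) before calling addStringVal.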

Using serialized tf.Example inputs

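import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.FloatList;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

public class ExampleClient {
    public static void main(String[] args) {
        // One feature "x" holding four float values
        FloatList.Builder floatList = FloatList.newBuilder();
        floatList.addValue(6.9F);
        floatList.addValue(3.1F);
        floatList.addValue(5.4F);
        floatList.addValue(2.1F);
        Feature feature = Feature.newBuilder().setFloatList(floatList).build();

        Features.Builder features = Features.newBuilder();
        features.putFeature("x", feature);

        Example example = Example.newBuilder().setFeatures(features).build();

        Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
        Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
        modelSpec.setName("mnist");
        modelSpec.setSignatureName("serving_default");
        request.setModelSpec(modelSpec);

        TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(1).build();
        TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();

        // The serialized Example is one element of a DT_STRING tensor of shape [1]
        TensorProto.Builder tensor = TensorProto.newBuilder();
        tensor.setTensorShape(shape);
        tensor.setDtype(DataType.DT_STRING);
        tensor.addStringVal(example.toByteString());
        request.putInputs("inputs", tensor.build());

        ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 9000).usePlaintext().build();
        try {
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            Predict.PredictResponse response = stub.predict(request.build());
            System.out.println(response);
        } finally {
            channel.shutdown();
        }
    }
}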

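The examples above each make a single prediction. For a batch, add multiple values in the tensor.addXxx calls and set the size of the corresponding dimension to match, for example: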

// Batch of 2: the dimension size equals the number of values added below
TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();

TensorProto.Builder tensor = TensorProto.newBuilder();
tensor.setTensorShape(shape);
tensor.setDtype(DataType.DT_STRING);
tensor.addStringVal(example.toByteString());
tensor.addStringVal(example.toByteString());
request.putInputs("inputs", tensor.build());
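TensorProto stores its typed values (int_val, float_val, string_val, ...) as a flat list in row-major order. For a batch, the number of values added should therefore equal the product of the dimension sizes. That is 2 for the [2] string tensor above, and rows × columns for a 2-D tensor such as a batch of raw float feature rows.

The following pure-Java sketch shows that bookkeeping. BatchShapes and its methods are hypothetical helpers, and the second row of numbers is made-up sample data.

```java
import java.util.Arrays;

// Hypothetical helpers for sizing batched tensors before filling a TensorProto.
final class BatchShapes {
    private BatchShapes() {}

    /** Number of elements implied by a shape; an empty shape is a scalar with 1 element. */
    static long elementCount(long... dims) {
        long n = 1;
        for (long d : dims) {
            if (d < 0) throw new IllegalArgumentException("unknown dimension: " + d);
            n *= d;
        }
        return n;
    }

    /** Flattens equal-length rows in row-major order (all of row 0, then row 1, ...). */
    static float[] flattenRowMajor(float[][] rows) {
        int cols = rows.length == 0 ? 0 : rows[0].length;
        float[] flat = new float[rows.length * cols];
        for (int i = 0; i < rows.length; i++) {
            if (rows[i].length != cols) throw new IllegalArgumentException("ragged row " + i);
            System.arraycopy(rows[i], 0, flat, i * cols, cols);
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] batch = {{6.9F, 3.1F, 5.4F, 2.1F}, {5.1F, 3.5F, 1.4F, 0.2F}};
        float[] flat = flattenRowMajor(batch);
        long[] shape = {batch.length, batch[0].length};
        // These would be passed to tensor.addFloatVal(...) in order, with dims [2, 4]
        if (flat.length != elementCount(shape)) throw new IllegalStateException("shape mismatch");
        System.out.println(Arrays.toString(shape) + " -> " + flat.length + " values");
    }
}
```

Checking the count on the client side catches shape mistakes before the request reaches the server.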

For the complete project code, see https://github.com/henryhyn/tensorflow-serving-api

References