The code needed for a TensorFlow Serving Java client lives in the https://github.com/tensorflow/tensorflow and https://github.com/tensorflow/serving projects. It is defined as .proto files, so we first need to install the protoc tool and use it to generate Java code from those files.
Installing protoc 3 on Mac
Prebuilt binaries of protoc 3 are available. Download the prebuilt package protoc-3.5.1-osx-x86_64.zip from the official release page https://github.com/google/protobuf/releases, then copy the protoc binary to /usr/local/bin.
cd /tmp
# Assumes the zip was downloaded to ~/Downloads; adjust the path as needed
mv ~/Downloads/protoc-3.5.1-osx-x86_64.zip .
unzip protoc-3.5.1-osx-x86_64.zip
cd bin
cp protoc /usr/local/bin/
Verify the installation:
$ protoc --version
libprotoc 3.5.1
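If you script this setup, the version check above can be automated. The following stdlib-only Java sketch shows one way to do it. It assumes protoc lives at /usr/local/bin/protoc and prints a single line of the form libprotoc X.Y.Z, as shown above. The class and method names are hypothetical.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

public class ProtocVersionCheck {
    // Extracts "3.5.1" from a line such as "libprotoc 3.5.1".
    static String parseVersion(String line) {
        String prefix = "libprotoc ";
        if (line == null || !line.startsWith(prefix)) {
            throw new IllegalArgumentException("unexpected protoc output: " + line);
        }
        return line.substring(prefix.length()).trim();
    }

    // Returns the major version number, e.g. 3 for "3.5.1".
    static int majorVersion(String version) {
        return Integer.parseInt(version.split("\\.")[0]);
    }

    // Runs "<protocPath> --version" and returns the first line it prints.
    static String runProtocVersion(String protocPath) throws IOException, InterruptedException {
        Process p = new ProcessBuilder(protocPath, "--version").redirectErrorStream(true).start();
        try (BufferedReader r = new BufferedReader(
                new InputStreamReader(p.getInputStream(), StandardCharsets.UTF_8))) {
            String line = r.readLine();
            p.waitFor();
            return line;
        }
    }

    public static void main(String[] args) throws Exception {
        String version = parseVersion(runProtocVersion("/usr/local/bin/protoc"));
        System.out.println("protoc " + version + ", major version " + majorVersion(version));
    }
}
```

The parsing is kept separate from the process call, so the same helpers also work for the protoc 2 check below (major version 2).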
Installing protoc 2 on Mac
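protoc 2 is rarely needed nowadays. Some .proto files may have been written for proto2, though, so its installation is covered here as well. Note that protoc 3 can also compile proto2-syntax files, so protoc 2 mainly matters when the generated code has to work with a protobuf-java 2.x runtime. There is no official prebuilt protoc 2 binary for Mac, so it has to be built from source. Download protobuf-2.6.1 from https://github.com/google/protobuf/releases.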
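Build and install it. With the default prefix (/usr/local), make install overwrites the protoc 3 binary installed above. To keep both versions, pass another prefix, e.g. ./configure --prefix=/opt/protoc-2.6.1, and call that protoc by its full path.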
cd /opt
# Assumes the tarball was downloaded to ~/Downloads; adjust the path as needed
mv ~/Downloads/protobuf-2.6.1.tar.gz .
tar -xkzvf protobuf-2.6.1.tar.gz
cd protobuf-2.6.1
./configure && make && make install
Verify the installation:
$ protoc --version
libprotoc 2.6.1
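protoc 2 cannot compile files that declare syntax = "proto3". It therefore helps to check which syntax the .proto files use before choosing a compiler. Below is a minimal sketch, assuming the files sit under src/main/proto (the directory used later in this article). ProtoSyntaxCheck and its methods are hypothetical names. The sketch relies on the protobuf rule that a file without a syntax statement is treated as proto2.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class ProtoSyntaxCheck {
    // Matches a syntax statement at the start of a line, e.g.: syntax = "proto3";
    private static final Pattern SYNTAX =
            Pattern.compile("^\\s*syntax\\s*=\\s*[\"'](proto[23])[\"']\\s*;", Pattern.MULTILINE);

    // Returns "proto2" or "proto3"; files without a syntax statement count as proto2.
    static String syntaxOf(String source) {
        Matcher m = SYNTAX.matcher(source);
        return m.find() ? m.group(1) : "proto2";
    }

    // Lists every .proto file under root that declares proto3 syntax.
    static List<Path> proto3Files(Path root) throws IOException {
        try (Stream<Path> paths = Files.walk(root)) {
            return paths
                    .filter(p -> p.toString().endsWith(".proto"))
                    .filter(p -> {
                        try {
                            String text = new String(Files.readAllBytes(p), StandardCharsets.UTF_8);
                            return syntaxOf(text).equals("proto3");
                        } catch (IOException e) {
                            throw new UncheckedIOException(e);
                        }
                    })
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        Path root = Paths.get(args.length > 0 ? args[0] : "src/main/proto");
        if (!Files.isDirectory(root)) {
            System.out.println(root + " does not exist");
            return;
        }
        List<Path> files = proto3Files(root);
        System.out.println(files.size() + " proto3 file(s); these need protoc 3.x");
    }
}
```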
Generating Java code with a Maven plugin
Create a new Maven project, for example tensorflow-serving-api, and add the following Maven configuration:
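<build>
  <extensions>
    <!-- Provides ${os.detected.classifier}, used to pick the platform-specific gRPC code generator -->
    <extension>
      <groupId>kr.motd.maven</groupId>
      <artifactId>os-maven-plugin</artifactId>
      <version>1.5.0.Final</version>
    </extension>
  </extensions>
  <plugins>
    <plugin>
      <groupId>org.xolstice.maven.plugins</groupId>
      <artifactId>protobuf-maven-plugin</artifactId>
      <version>0.5.1</version>
      <configuration>
        <protocExecutable>/usr/local/bin/protoc</protocExecutable>
        <!-- compile-custom uses this plugin to generate gRPC stubs such as PredictionServiceGrpc -->
        <pluginId>grpc-java</pluginId>
        <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.12.0:exe:${os.detected.classifier}</pluginArtifact>
      </configuration>
      <executions>
        <execution>
          <goals>
            <goal>compile</goal>
            <goal>compile-custom</goal>
            <goal>test-compile</goal>
          </goals>
        </execution>
      </executions>
    </plugin>
  </plugins>
</build>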
<dependencies>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>3.5.1</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>1.12.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.12.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty</artifactId>
<version>1.12.0</version>
</dependency>
</dependencies>
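Then run mvn protobuf:compile to generate Java code from the project's *.proto files. The first time you run it, the plugin reports that the src/main/proto directory is missing. Create that directory, then copy the following files from the two projects mentioned above into it.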
src/main/proto
├── tensorflow
│ └── core
│ ├── example
│ │ ├── example.proto
│ │ └── feature.proto
│ ├── framework
│ │ ├── attr_value.proto
│ │ ├── function.proto
│ │ ├── graph.proto
│ │ ├── node_def.proto
│ │ ├── op_def.proto
│ │ ├── resource_handle.proto
│ │ ├── tensor.proto
│ │ ├── tensor_shape.proto
│ │ ├── types.proto
│ │ └── versions.proto
│ └── protobuf
│ ├── meta_graph.proto
│ └── saver.proto
└── tensorflow_serving
└── apis
├── classification.proto
├── get_model_metadata.proto
├── inference.proto
├── input.proto
├── model.proto
├── predict.proto
├── prediction_service.proto
└── regression.proto
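These files import one another by repository-relative paths. For example, predict.proto imports "tensorflow/core/framework/tensor.proto". The layout under src/main/proto must therefore mirror each repository's layout. The stdlib-only sketch below copies exactly the files in the tree above from two local clones. It assumes the clones live at ../tensorflow and ../serving; these are hypothetical defaults, so pass your own paths as arguments. The list matches the TensorFlow/Serving versions of the time. Newer releases may import additional files, in which case protoc reports the missing imports.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.Arrays;
import java.util.List;

public class CopyProtos {
    // Files taken from the tensorflow/tensorflow repository (same list as the tree above).
    static final List<String> TENSORFLOW_PROTOS = Arrays.asList(
            "tensorflow/core/example/example.proto",
            "tensorflow/core/example/feature.proto",
            "tensorflow/core/framework/attr_value.proto",
            "tensorflow/core/framework/function.proto",
            "tensorflow/core/framework/graph.proto",
            "tensorflow/core/framework/node_def.proto",
            "tensorflow/core/framework/op_def.proto",
            "tensorflow/core/framework/resource_handle.proto",
            "tensorflow/core/framework/tensor.proto",
            "tensorflow/core/framework/tensor_shape.proto",
            "tensorflow/core/framework/types.proto",
            "tensorflow/core/framework/versions.proto",
            "tensorflow/core/protobuf/meta_graph.proto",
            "tensorflow/core/protobuf/saver.proto");

    // Files taken from the tensorflow/serving repository.
    static final List<String> SERVING_PROTOS = Arrays.asList(
            "tensorflow_serving/apis/classification.proto",
            "tensorflow_serving/apis/get_model_metadata.proto",
            "tensorflow_serving/apis/inference.proto",
            "tensorflow_serving/apis/input.proto",
            "tensorflow_serving/apis/model.proto",
            "tensorflow_serving/apis/predict.proto",
            "tensorflow_serving/apis/prediction_service.proto",
            "tensorflow_serving/apis/regression.proto");

    // The destination mirrors the repository-relative path so import statements resolve.
    static Path target(Path protoRoot, String relative) {
        return protoRoot.resolve(relative);
    }

    // Copies each listed file from a local clone into protoRoot, creating directories as needed.
    static void copyAll(Path repoRoot, List<String> files, Path protoRoot) throws IOException {
        for (String rel : files) {
            Path dest = target(protoRoot, rel);
            Files.createDirectories(dest.getParent());
            Files.copy(repoRoot.resolve(rel), dest, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical clone locations; override them with command-line arguments.
        Path tensorflow = Paths.get(args.length > 0 ? args[0] : "../tensorflow");
        Path serving = Paths.get(args.length > 1 ? args[1] : "../serving");
        Path protoRoot = Paths.get("src/main/proto");
        copyAll(tensorflow, TENSORFLOW_PROTOS, protoRoot);
        copyAll(serving, SERVING_PROTOS, protoRoot);
        System.out.println("copied " + (TENSORFLOW_PROTOS.size() + SERVING_PROTOS.size()) + " files");
    }
}
```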
Now run mvn clean protobuf:compile to generate the Java sources, or run mvn clean package to also compile them and build a usable tensorflow-serving-api jar.
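A quick sanity check is to confirm that the jar actually contains the classes the client code below imports. Below is a stdlib-only sketch using java.util.jar.JarFile. The class names assume the package layout implied by the client code: tensorflow.serving for the Serving APIs, org.tensorflow.framework and org.tensorflow.example for the TensorFlow protos. The default jar path assumes Maven's target/<artifactId>-<version>.jar naming with version 1.0-SNAPSHOT. Both are assumptions, and JarCheck is a hypothetical name.

```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

public class JarCheck {
    // Class files the client examples rely on (assumed package layout, see lead-in).
    static final List<String> REQUIRED = Arrays.asList(
            "tensorflow/serving/Predict.class",
            "tensorflow/serving/PredictionServiceGrpc.class",
            "org/tensorflow/framework/TensorProto.class",
            "org/tensorflow/example/Example.class");

    // Reads the names of all entries in a jar file.
    static Set<String> entriesOf(String jarPath) throws IOException {
        Set<String> names = new HashSet<>();
        try (JarFile jar = new JarFile(jarPath)) {
            for (JarEntry e : Collections.list(jar.entries())) {
                names.add(e.getName());
            }
        }
        return names;
    }

    // Returns the required entries that are absent, in the order they were listed.
    static List<String> missing(Set<String> entries, List<String> required) {
        List<String> result = new ArrayList<>();
        for (String r : required) {
            if (!entries.contains(r)) {
                result.add(r);
            }
        }
        return result;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical default path; pass the real jar path as the first argument.
        String jarPath = args.length > 0 ? args[0] : "target/tensorflow-serving-api-1.0-SNAPSHOT.jar";
        List<String> absent = missing(entriesOf(jarPath), REQUIRED);
        System.out.println(absent.isEmpty() ? "all required classes present" : "missing: " + absent);
    }
}
```

If PredictionServiceGrpc.class is reported missing, the gRPC stubs were not generated. The message classes come from protoc itself, but the service stub needs the separate grpc-java code generator.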
TensorFlow Serving Java Client
Using Tensor inputs
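// Generated classes from the tensorflow-serving-api jar built above, plus gRPC
import com.google.protobuf.ByteString;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.nio.charset.StandardCharsets;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

// Model name and signature to call
Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
modelSpec.setName("mnist");
modelSpec.setSignatureName("serving_default");
request.setModelSpec(modelSpec);
TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();
// input_label: int32 tensor of shape [1]
TensorProto.Builder tensor = TensorProto.newBuilder();
tensor.setTensorShape(shape);
tensor.setDtype(DataType.DT_INT32);
tensor.addIntVal(10);
request.putInputs("input_label", tensor.build());
// input_feature: string tensor of shape [1]; the text is sent as UTF-8 bytes
tensor.clear();
tensor.setTensorShape(shape);
tensor.setDtype(DataType.DT_STRING);
tensor.addStringVal(ByteString.copyFrom("新品,冰淇淋,不错,哈根达,斯,冰淇淋,不错,麻薯,口味,单一,口味,时代,吃,抹茶,口味,终于,选择,甜,吃,舒服,一点,吃,完,口,渴", StandardCharsets.UTF_8));
request.putInputs("input_feature", tensor.build());
// keep_prob: float scalar (no shape set)
tensor.clear();
tensor.setDtype(DataType.DT_FLOAT);
tensor.addFloatVal(1.0F);
request.putInputs("keep_prob", tensor.build());
// Plaintext gRPC channel to the server at localhost:9000
ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 9000).usePlaintext().build();
try {
    PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
    Predict.PredictResponse response = stub.predict(request.build());
    System.out.println(response);
} finally {
    // Release the channel's threads and connections
    channel.shutdown();
}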
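The input_feature value above is a list of pre-segmented words joined by commas and sent as raw bytes in a DT_STRING tensor. The comma layout is specific to this particular model, so treat it as an assumption for other models. The point that carries over is that DT_STRING holds bytes, so the client should encode the text explicitly as UTF-8. Below is a stdlib-only sketch of that preparation step. FeatureText and its methods are hypothetical. In the client, the resulting bytes would go to ByteString.copyFrom(byte[]).

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

public class FeatureText {
    // Joins pre-segmented tokens with commas, mirroring the input_feature string above.
    static String joinTokens(List<String> tokens) {
        return String.join(",", tokens);
    }

    // Encodes the text as UTF-8, the bytes a DT_STRING tensor element would carry.
    static byte[] toUtf8(String text) {
        return text.getBytes(StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        // "ice cream" and "not bad", two tokens from the example string, written as Unicode escapes
        String text = joinTokens(Arrays.asList("\u51b0\u6dc7\u6dcb", "\u4e0d\u9519"));
        byte[] bytes = toUtf8(text);
        // 5 CJK characters at 3 bytes each plus one comma: prints "16 bytes"
        System.out.println(bytes.length + " bytes");
    }
}
```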
Using tf.Example (TFRecord) inputs
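import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;
import org.tensorflow.example.FloatList;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

// Build one tf.Example whose feature "x" holds four floats
FloatList.Builder floatList = FloatList.newBuilder();
floatList.addValue(6.9F);
floatList.addValue(3.1F);
floatList.addValue(5.4F);
floatList.addValue(2.1F);
Feature feature = Feature.newBuilder().setFloatList(floatList).build();
Features.Builder features = Features.newBuilder();
features.putFeature("x", feature);
Example example = Example.newBuilder().setFeatures(features).build();
// Model name and signature to call
Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
modelSpec.setName("mnist");
modelSpec.setSignatureName("serving_default");
request.setModelSpec(modelSpec);
// inputs: string tensor of shape [1] holding the serialized Example
TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();
TensorProto.Builder tensor = TensorProto.newBuilder();
tensor.setTensorShape(shape);
tensor.setDtype(DataType.DT_STRING);
tensor.addStringVal(example.toByteString());
request.putInputs("inputs", tensor.build());
ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 9000).usePlaintext().build();
try {
    PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
    Predict.PredictResponse response = stub.predict(request.build());
    System.out.println(response);
} finally {
    channel.shutdown();
}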
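The examples above send a single instance. To send a batch, add several values in the tensor.addXxxVal step and set the size of the corresponding dimension to match, for example: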
TensorShapeProto.Dim dim = TensorShapeProto.Dim.newBuilder().setSize(2).build();
TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim).build();
TensorProto.Builder tensor = TensorProto.newBuilder();
tensor.setTensorShape(shape);
tensor.setDtype(DataType.DT_STRING);
tensor.addStringVal(example.toByteString());
tensor.addStringVal(example.toByteString());
request.putInputs("inputs", tensor.build());
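The rule behind the batch example is that a tensor must hold exactly as many values as the product of its dimension sizes. TensorProto stores these values flattened in row-major order. A scalar (no dimensions) therefore holds one value, as keep_prob does above. Shape [2] holds two serialized Examples. A hypothetical dense float input of shape [2, 4] would need eight addFloatVal calls. Below is a stdlib-only sketch of that check. BatchShape is a hypothetical helper, not part of the generated API.

```java
import java.util.Arrays;

public class BatchShape {
    // Number of values a tensor of the given shape must hold: the product of its dimension sizes.
    static long numElements(long... dims) {
        long n = 1;
        for (long d : dims) {
            n *= d;
        }
        return n;
    }

    // True when the number of values added via addXxxVal matches the declared shape.
    static boolean consistent(int valueCount, long... dims) {
        return numElements(dims) == valueCount;
    }

    public static void main(String[] args) {
        // Batch of 2 serialized Examples: shape [2], 2 string values
        System.out.println(consistent(2, 2));    // true
        // Hypothetical batch of 2 rows with 4 floats each: shape [2, 4], 8 float values
        System.out.println(consistent(8, 2, 4)); // true
        System.out.println(Arrays.toString(new long[]{2, 4}) + " holds " + numElements(2, 4) + " values");
    }
}
```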
The complete project code is available at https://github.com/henryhyn/tensorflow-serving-api