summaryrefslogtreecommitdiffstats
path: root/java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java
diff options
context:
space:
mode:
Diffstat (limited to 'java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java')
-rw-r--r--java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java194
1 files changed, 194 insertions, 0 deletions
diff --git a/java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java b/java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java
new file mode 100644
index 000000000..6101ed516
--- /dev/null
+++ b/java/com/android/voicemail/impl/transcribe/grpc/TranscriptionClientFactory.java
@@ -0,0 +1,194 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+package com.android.voicemail.impl.transcribe.grpc;
+
+import android.content.Context;
+import android.content.pm.PackageInfo;
+import android.content.pm.PackageManager;
+import android.text.TextUtils;
+import com.android.dialer.common.Assert;
+import com.android.dialer.common.LogUtil;
+import com.android.voicemail.impl.transcribe.TranscriptionConfigProvider;
+import com.google.internal.communications.voicemailtranscription.v1.VoicemailTranscriptionServiceGrpc;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ClientInterceptors;
+import io.grpc.ForwardingClientCall;
+import io.grpc.ManagedChannel;
+import io.grpc.ManagedChannelBuilder;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.okhttp.OkHttpChannelBuilder;
+import java.security.MessageDigest;
+
+/**
+ * Factory for creating grpc clients that talk to the transcription server. This allows all clients
+ * to share the same channel, which is relatively expensive to create.
+ */
+public class TranscriptionClientFactory {
+ private static final String DIGEST_ALGORITHM_SHA1 = "SHA1";
+ private static final char[] HEX_UPPERCASE = {
+ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
+ };
+
+ private final TranscriptionConfigProvider configProvider;
+ private final ManagedChannel originalChannel;
+ private final String packageName;
+ private final String cert;
+
+ public TranscriptionClientFactory(Context context, TranscriptionConfigProvider configProvider) {
+ this(context, configProvider, getManagedChannel(configProvider));
+ }
+
+ public TranscriptionClientFactory(
+ Context context, TranscriptionConfigProvider configProvider, ManagedChannel managedChannel) {
+ this.configProvider = configProvider;
+ this.packageName = context.getPackageName();
+ this.cert = getCertificateFingerprint(context);
+ originalChannel = managedChannel;
+ }
+
+ public TranscriptionClient getClient() {
+ LogUtil.enterBlock("TranscriptionClientFactory.getClient");
+ Assert.checkState(!originalChannel.isShutdown());
+ Channel channel =
+ ClientInterceptors.intercept(
+ originalChannel,
+ new Interceptor(
+ packageName, cert, configProvider.getApiKey(), configProvider.getAuthToken()));
+ return new TranscriptionClient(VoicemailTranscriptionServiceGrpc.newBlockingStub(channel));
+ }
+
+ public void shutdown() {
+ LogUtil.enterBlock("TranscriptionClientFactory.shutdown");
+ originalChannel.shutdown();
+ }
+
+ private static ManagedChannel getManagedChannel(TranscriptionConfigProvider configProvider) {
+ ManagedChannelBuilder<OkHttpChannelBuilder> builder =
+ OkHttpChannelBuilder.forTarget(configProvider.getServerAddress());
+ // Only use plaintext for debugging
+ if (configProvider.shouldUsePlaintext()) {
+ // Just passing 'false' doesnt have the same effect as not setting this field
+ builder.usePlaintext(true);
+ }
+ return builder.build();
+ }
+
+ private static String getCertificateFingerprint(Context context) {
+ try {
+ PackageInfo packageInfo =
+ context
+ .getPackageManager()
+ .getPackageInfo(context.getPackageName(), PackageManager.GET_SIGNATURES);
+ if (packageInfo != null
+ && packageInfo.signatures != null
+ && packageInfo.signatures.length > 0) {
+ MessageDigest messageDigest = MessageDigest.getInstance(DIGEST_ALGORITHM_SHA1);
+ if (messageDigest == null) {
+ LogUtil.w(
+ "TranscriptionClientFactory.getCertificateFingerprint", "error getting digest.");
+ return null;
+ }
+ byte[] bytes = messageDigest.digest(packageInfo.signatures[0].toByteArray());
+ if (bytes == null) {
+ LogUtil.w(
+ "TranscriptionClientFactory.getCertificateFingerprint", "empty message digest.");
+ return null;
+ }
+
+ int length = bytes.length;
+ StringBuilder out = new StringBuilder(length * 2);
+ for (int i = 0; i < length; i++) {
+ out.append(HEX_UPPERCASE[(bytes[i] & 0xf0) >>> 4]);
+ out.append(HEX_UPPERCASE[bytes[i] & 0x0f]);
+ }
+ return out.toString();
+ } else {
+ LogUtil.w(
+ "TranscriptionClientFactory.getCertificateFingerprint",
+ "failed to get package signature.");
+ }
+ } catch (Exception e) {
+ LogUtil.e(
+ "TranscriptionClientFactory.getCertificateFingerprint",
+ "error getting certificate fingerprint.",
+ e);
+ }
+
+ return null;
+ }
+
+ private static final class Interceptor implements ClientInterceptor {
+ private final String packageName;
+ private final String cert;
+ private final String apiKey;
+ private final String authToken;
+
+ private static final Metadata.Key<String> API_KEY_HEADER =
+ Metadata.Key.of("X-Goog-Api-Key", Metadata.ASCII_STRING_MARSHALLER);
+ private static final Metadata.Key<String> ANDROID_PACKAGE_HEADER =
+ Metadata.Key.of("X-Android-Package", Metadata.ASCII_STRING_MARSHALLER);
+ private static final Metadata.Key<String> ANDROID_CERT_HEADER =
+ Metadata.Key.of("X-Android-Cert", Metadata.ASCII_STRING_MARSHALLER);
+ private static final Metadata.Key<String> AUTHORIZATION_HEADER =
+ Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER);
+
+ public Interceptor(String packageName, String cert, String apiKey, String authToken) {
+ this.packageName = packageName;
+ this.cert = cert;
+ this.apiKey = apiKey;
+ this.authToken = authToken;
+ }
+
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+ MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+ LogUtil.enterBlock(
+ "TranscriptionClientFactory.interceptCall, intercepted " + method.getFullMethodName());
+ ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
+
+ call =
+ new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(call) {
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ if (!TextUtils.isEmpty(packageName)) {
+ LogUtil.i(
+ "TranscriptionClientFactory.interceptCall",
+ "attaching package name: " + packageName);
+ headers.put(ANDROID_PACKAGE_HEADER, packageName);
+ }
+ if (!TextUtils.isEmpty(cert)) {
+ LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching android cert");
+ headers.put(ANDROID_CERT_HEADER, cert);
+ }
+ if (!TextUtils.isEmpty(apiKey)) {
+ LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching API Key");
+ headers.put(API_KEY_HEADER, apiKey);
+ }
+ if (!TextUtils.isEmpty(authToken)) {
+ LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching auth token");
+ headers.put(AUTHORIZATION_HEADER, "Bearer " + authToken);
+ }
+ super.start(responseListener, headers);
+ }
+ };
+ return call;
+ }
+ }
+}