diff options
| -rw-r--r-- | .gitignore | 4 | ||||
| -rw-r--r-- | dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java | 7 | ||||
| -rw-r--r-- | res/values/strings.xml | 2 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/app/BenchmarkTestBase.java | 93 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/app/NNBenchmark.java | 41 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/app/NNScoringTest.java | 30 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/core/NNTestBase.java | 32 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/core/Processor.java | 174 | ||||
| -rw-r--r-- | src/com/android/nn/benchmark/core/TestModels.java | 23 |
9 files changed, 234 insertions, 172 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1519f47 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.iml +**/gen/* +**/.idea/* + diff --git a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java index 4334d84..21ce236 100644 --- a/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java +++ b/dogfood/src/com/android/nn/dogfood/BenchmarkJobService.java @@ -28,12 +28,13 @@ import androidx.core.app.NotificationCompat; import androidx.core.app.NotificationManagerCompat; import com.android.nn.benchmark.core.BenchmarkResult; -import com.android.nn.benchmark.core.NNTestBase; import com.android.nn.benchmark.core.Processor; import com.android.nn.benchmark.core.TestModels; import java.util.List; import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** Regularly runs a random selection of the NN API benchmark models */ public class BenchmarkJobService extends JobService implements Processor.Callback { @@ -52,6 +53,7 @@ public class BenchmarkJobService extends JobService implements Processor.Callbac private static int DOGFOOD_MODELS_PER_RUN = 20; private BenchmarkResult mTestResults[]; + private final ExecutorService processorRunner = Executors.newSingleThreadExecutor(); @Override @@ -74,11 +76,10 @@ public class BenchmarkJobService extends JobService implements Processor.Callbac } public void doBenchmark() { - mProcessor = new Processor(this, this, randomModelList()); mProcessor.setUseNNApi(true); mProcessor.setToggleLong(true); - mProcessor.start(); + processorRunner.submit(mProcessor); } public void onBenchmarkFinish(boolean ok) { diff --git a/res/values/strings.xml b/res/values/strings.xml index 180e0cb..aab5c97 100644 --- a/res/values/strings.xml +++ b/res/values/strings.xml @@ -33,7 +33,7 @@ <string name="ok">Ok</string> <string name="cancel">Cancel</string> <string name="settings">settings</string> - <string-array + <string-array name="settings_array"> <item>Run each test longer, 10 seconds</item> <item>Pause 10 seconds between tests</item> diff --git a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java index 5aaa056..a89626f 100644 --- a/src/com/android/nn/benchmark/app/BenchmarkTestBase.java +++ b/src/com/android/nn/benchmark/app/BenchmarkTestBase.java @@ -23,7 +23,6 @@ import android.content.Context; import android.content.Intent; import android.content.IntentFilter; import android.os.BatteryManager; -import android.os.Bundle; import android.os.Trace; import android.test.ActivityInstrumentationTestCase2; import android.util.Log; @@ -35,6 +34,8 @@ import com.android.nn.benchmark.core.BenchmarkResult; import com.android.nn.benchmark.core.TestModels; import com.android.nn.benchmark.core.TestModels.TestModelEntry; +import java.util.concurrent.CountDownLatch; + import org.junit.After; import org.junit.Before; import org.junit.runner.RunWith; @@ -51,6 +52,7 @@ import java.util.List; */ @RunWith(Parameterized.class) public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchmark> { + // Only run 1 iteration now to fit the MediumTest time requirement. // One iteration means running the tests continuous for 1s. private NNBenchmark mActivity; @@ -95,33 +97,32 @@ public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchm protected void waitUntilCharged() { Log.v(NNBenchmark.TAG, "Waiting for the device to charge"); - Object lock = new Object(); + final CountDownLatch chargedLatch = new CountDownLatch(1); BroadcastReceiver receiver = new BroadcastReceiver() { @Override public void onReceive(Context context, Intent intent) { int level = intent.getIntExtra(BatteryManager.EXTRA_LEVEL, -1); int scale = intent.getIntExtra(BatteryManager.EXTRA_SCALE, -1); - int percentage = level * 100 / scale; + int percentage = level * 100 / scale; Log.v(NNBenchmark.TAG, "Battery level: " + percentage + "%"); int status = intent.getIntExtra(BatteryManager.EXTRA_STATUS, -1); if (status == BatteryManager.BATTERY_STATUS_FULL) { - synchronized (lock) { - lock.notify(); - } + chargedLatch.countDown(); } else if (status != BatteryManager.BATTERY_STATUS_CHARGING) { - Log.e(NNBenchmark.TAG, "Device is not charging"); + Log.e(NNBenchmark.TAG, + String.format("Device is not charging, status is %d", status)); } } }; mActivity.registerReceiver(receiver, new IntentFilter(Intent.ACTION_BATTERY_CHANGED)); - synchronized (lock) { - try { - lock.wait(); - } catch (InterruptedException e) { - } + try { + chargedLatch.await(); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); } + mActivity.unregisterReceiver(receiver); } @@ -139,34 +140,49 @@ public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchm super.tearDown(); } - class TestAction implements Runnable { - TestModelEntry mTestModel; + interface Joinable extends Runnable { + // Syncrhonises the caller with the completion of the current action + void join(); + } + + class TestAction implements Joinable { + + private final TestModelEntry mTestModel; + private final float mWarmupTimeSeconds; + private final float mRunTimeSeconds; + private final CountDownLatch actionComplete; + BenchmarkResult mResult; - float mWarmupTimeSeconds; - float mRunTimeSeconds; Throwable mException; - public TestAction(TestModelEntry testName) { - mTestModel = testName; - } public TestAction(TestModelEntry testName, float warmupTimeSeconds, float runTimeSeconds) { mTestModel = testName; mWarmupTimeSeconds = warmupTimeSeconds; mRunTimeSeconds = runTimeSeconds; + actionComplete = new CountDownLatch(1); } public void run() { + Log.v(NNBenchmark.TAG, String.format( + "Starting benchmark for test '%s' running for at least %f seconds", + mTestModel.mTestName, + mRunTimeSeconds)); try { - mResult = mActivity.mProcessor.getInstrumentationResult( - mTestModel, mWarmupTimeSeconds, mRunTimeSeconds); - } catch (IOException e) { + mResult = mActivity.runSynchronously( + mTestModel, mWarmupTimeSeconds, mRunTimeSeconds); + Log.v(NNBenchmark.TAG, + String.format("Benchmark for test '%s' is: %s", mTestModel, mResult)); + } catch (BenchmarkException | IOException e) { mException = e; - e.printStackTrace(); - } - Log.v(NNBenchmark.TAG, - "Benchmark for test \"" + mTestModel.toString() + "\" is: " + mResult); - synchronized (this) { - this.notify(); + Log.e(NNBenchmark.TAG, + String.format("Error running Benchmark for test '%s'", mTestModel), e); + } catch (Throwable e) { + mException = e; + Log.e(NNBenchmark.TAG, + String.format("Failure running Benchmark for test '%s'!!", mTestModel), e); + throw e; + } finally { + actionComplete.countDown(); } } @@ -176,22 +192,25 @@ public class BenchmarkTestBase extends ActivityInstrumentationTestCase2<NNBenchm } return mResult; } - } - // Set the benchmark thread to run on ui thread - // Synchronized the thread such that the test will wait for the benchmark thread to finish - public void runOnUiThread(Runnable action) { - synchronized (action) { - mActivity.runOnUiThread(action); + @Override + public void join() { try { - action.wait(); + actionComplete.await(); } catch (InterruptedException e) { - Log.v(NNBenchmark.TAG, "waiting for action running on UI thread is interrupted: " + - e.toString()); + Thread.currentThread().interrupt(); + Log.v(NNBenchmark.TAG, "Interrupted while waiting for action running", e); } } } + // Set the benchmark thread to run on ui thread + // Synchronized the thread such that the test will wait for the benchmark thread to finish + public void runOnUiThread(Joinable action) { + mActivity.runOnUiThread(action); + action.join(); + } + public void runTest(TestAction ta, String testName) { float sum = 0; // For NNAPI systrace usage documentation, see diff --git a/src/com/android/nn/benchmark/app/NNBenchmark.java b/src/com/android/nn/benchmark/app/NNBenchmark.java index a45e043..2ea4478 100644 --- a/src/com/android/nn/benchmark/app/NNBenchmark.java +++ b/src/com/android/nn/benchmark/app/NNBenchmark.java @@ -19,33 +19,38 @@ package com.android.nn.benchmark.app; import android.app.Activity; import android.content.Intent; import android.os.Bundle; +import android.util.Log; import android.view.WindowManager; import android.widget.TextView; - +import com.android.nn.benchmark.core.BenchmarkException; import com.android.nn.benchmark.core.BenchmarkResult; import com.android.nn.benchmark.core.Processor; +import com.android.nn.benchmark.core.TestModels.TestModelEntry; +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; public class NNBenchmark extends Activity implements Processor.Callback { - protected static final String TAG = "NN_BENCHMARK"; + public static final String TAG = "NN_BENCHMARK"; public static final String EXTRA_ENABLE_LONG = "enable long"; public static final String EXTRA_ENABLE_PAUSE = "enable pause"; public static final String EXTRA_DISABLE_NNAPI = "disable NNAPI"; - public static final String EXTRA_DEMO = "demo"; public static final String EXTRA_TESTS = "tests"; public static final String EXTRA_RESULTS_TESTS = "tests"; public static final String EXTRA_RESULTS_RESULTS = "results"; private int mTestList[]; - private BenchmarkResult mTestResults[]; + + private Processor mProcessor; + private final ExecutorService executorService = Executors.newSingleThreadExecutor(); private TextView mTextView; // Initialize the parameters for Instrumentation tests. protected void prepareInstrumentationTest() { mTestList = new int[1]; - mTestResults = new BenchmarkResult[1]; mProcessor = new Processor(this, this, mTestList); } @@ -57,9 +62,6 @@ public class NNBenchmark extends Activity implements Processor.Callback { mProcessor.setCompleteInputSet(completeInputSet); } - private boolean mDoingBenchmark; - public Processor mProcessor; - @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -74,7 +76,8 @@ public class NNBenchmark extends Activity implements Processor.Callback { protected void onPause() { super.onPause(); if (mProcessor != null) { - mProcessor.exit(); + mProcessor.exitWithTimeout(30000l); + mProcessor = null; } } @@ -104,12 +107,15 @@ public class NNBenchmark extends Activity implements Processor.Callback { super.onResume(); Intent i = getIntent(); mTestList = i.getIntArrayExtra(EXTRA_TESTS); - mProcessor = new Processor(this, this, mTestList); - mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false)); - mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false)); - mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false)); - if (mTestList != null) { - mProcessor.start(); + if (mTestList != null && mTestList.length > 0) { + Log.v(TAG, String.format("Starting benchmark with %d test", mTestList.length)); + mProcessor = new Processor(this, this, mTestList); + mProcessor.setToggleLong(i.getBooleanExtra(EXTRA_ENABLE_LONG, false)); + mProcessor.setTogglePause(i.getBooleanExtra(EXTRA_ENABLE_PAUSE, false)); + mProcessor.setUseNNApi(!i.getBooleanExtra(EXTRA_DISABLE_NNAPI, false)); + executorService.submit(mProcessor); + } else { + Log.v(TAG, "No test to run, doing nothing"); } } @@ -117,4 +123,9 @@ public class NNBenchmark extends Activity implements Processor.Callback { protected void onDestroy() { super.onDestroy(); } + + public BenchmarkResult runSynchronously(TestModelEntry testModel, + float warmupTimeSeconds, float runTimeSeconds) throws IOException, BenchmarkException { + return mProcessor.getInstrumentationResult(testModel, warmupTimeSeconds, runTimeSeconds); + } } diff --git a/src/com/android/nn/benchmark/app/NNScoringTest.java b/src/com/android/nn/benchmark/app/NNScoringTest.java index 1c339cf..4932eb1 100644 --- a/src/com/android/nn/benchmark/app/NNScoringTest.java +++ b/src/com/android/nn/benchmark/app/NNScoringTest.java @@ -36,6 +36,7 @@ import java.nio.file.Files; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import android.util.Log; /** * Tests that run all models/datasets/backend that are required for scoring the device. @@ -61,17 +62,15 @@ public class NNScoringTest extends BenchmarkTestBase { super.prepareTest(); } - @Test - @LargeTest - public void testTFLite() throws IOException { + private void test(boolean useNnapi) throws IOException { if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) { throw new IOException("No permission to store results in external storage"); } - setUseNNApi(false); + setUseNNApi(useNnapi); setCompleteInputSet(true); TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS, - COMPLETE_SET_TIMEOUT_SECOND); + COMPLETE_SET_TIMEOUT_SECOND); runTest(ta, mModel.getTestName()); try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) { @@ -81,21 +80,14 @@ public class NNScoringTest extends BenchmarkTestBase { @Test @LargeTest - public void testNNAPI() throws IOException { - if (!TestExternalStorageActivity.testWriteExternalStorage(getActivity(), false)) { - throw new IOException("No permission to store results in external storage"); - } - - setUseNNApi(true); - setCompleteInputSet(true); - TestAction ta = new TestAction(mModel, WARMUP_REPEATABLE_SECONDS, - COMPLETE_SET_TIMEOUT_SECOND); - runTest(ta, mModel.getTestName()); - + public void testTFLite() throws IOException { + test(false); + } - try (CSVWriter writer = new CSVWriter(getLocalCSVFile())) { - writer.write(ta.getBenchmark()); - } + @Test + @LargeTest + public void testNNAPI() throws IOException { + test(true); } public static File getLocalCSVFile() { diff --git a/src/com/android/nn/benchmark/core/NNTestBase.java b/src/com/android/nn/benchmark/core/NNTestBase.java index 579b4a2..21f8711 100644 --- a/src/com/android/nn/benchmark/core/NNTestBase.java +++ b/src/com/android/nn/benchmark/core/NNTestBase.java @@ -262,12 +262,13 @@ public class NNTestBase { Pair<List<InferenceInOutSequence>, List<InferenceResult>> result = runBenchmark(ios, totalSequenceInferencesCount, timeoutSec, flags); - if (result.second.size() != extpectedResults ) { + if (result.second.size() != extpectedResults) { // We reached a timeout or failed to evaluate whole set for other reason, abort. - throw new IllegalStateException( - "Failed to evaluate complete input set, expected: " - + extpectedResults + - ", received: " + result.second.size()); + final String errorMsg = "Failed to evaluate complete input set, expected: " + + extpectedResults + + ", received: " + result.second.size(); + Log.w(TAG, errorMsg); + throw new IllegalStateException(errorMsg); } return result; } @@ -283,7 +284,7 @@ public class NNTestBase { } List<InferenceResult> resultList = new ArrayList<>(); if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount, - timeoutSec, flags)) { + timeoutSec, flags)) { throw new BenchmarkException("Failed to run benchmark"); } return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>( @@ -303,21 +304,18 @@ public class NNTestBase { String modelAssetName = mModelFile + ".tflite"; AssetManager assetManager = mContext.getAssets(); try { - InputStream in = assetManager.open(modelAssetName); - outFileName = mContext.getCacheDir().getAbsolutePath() + "/" + modelAssetName; File outFile = new File(outFileName); - OutputStream out = new FileOutputStream(outFile); - byte[] buffer = new byte[1024]; - int read; - while ((read = in.read(buffer)) != -1) { - out.write(buffer, 0, read); - } - out.flush(); + try (InputStream in = assetManager.open(modelAssetName); + FileOutputStream out = new FileOutputStream(outFile)) { - in.close(); - out.close(); + byte[] byteBuffer = new byte[1024]; + int readBytes = -1; + while ((readBytes = in.read(byteBuffer)) != -1) { + out.write(byteBuffer, 0, readBytes); + } + } } catch (IOException e) { Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e); return null; diff --git a/src/com/android/nn/benchmark/core/Processor.java b/src/com/android/nn/benchmark/core/Processor.java index 1aa6008..0017082 100644 --- a/src/com/android/nn/benchmark/core/Processor.java +++ b/src/com/android/nn/benchmark/core/Processor.java @@ -16,6 +16,8 @@ package com.android.nn.benchmark.core; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + import android.content.Context; import android.os.Trace; import android.util.Log; @@ -23,21 +25,26 @@ import android.util.Pair; import java.io.IOException; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; /** Processor is a helper thread for running the work without blocking the UI thread. */ -public class Processor extends Thread { +public class Processor implements Runnable { public interface Callback { - public void onBenchmarkFinish(boolean ok); + void onBenchmarkFinish(boolean ok); - public void onStatusUpdate(int testNumber, int numTests, String modelName); + void onStatusUpdate(int testNumber, int numTests, String modelName); } protected static final String TAG = "NN_BENCHMARK"; private Context mContext; - private float mLastResult; - private boolean mRun = true; + private final AtomicBoolean mRun = new AtomicBoolean(true); + + volatile boolean mHasBeenStarted = false; + // You cannot restart a thread, so the completion flag is final + private final CountDownLatch mCompleted = new CountDownLatch(1); private boolean mDoingBenchmark; private NNTestBase mTest; private int mTestList[]; @@ -78,18 +85,24 @@ public class Processor extends Thread { // Method to retrieve benchmark results for instrumentation tests. public BenchmarkResult getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) - throws IOException { + throws IOException, BenchmarkException { mTest = changeTest(mTest, t); - return getBenchmark(warmupTimeSeconds, runTimeSeconds); + BenchmarkResult result = getBenchmark(warmupTimeSeconds, runTimeSeconds); + mTest.destroy(); + mTest = null; + return result; } - private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) { + private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) + throws BenchmarkException { if (oldTestBase != null) { // Make sure we don't leak memory. oldTestBase.destroy(); } NNTestBase tb = t.createNNTestBase(mUseNNApi, false /* enableIntermediateTensorsDump */); - tb.setupModel(mContext); + if (!tb.setupModel(mContext)) { + throw new BenchmarkException("Cannot initialise model"); + } return tb; } @@ -133,14 +146,10 @@ public class Processor extends Thread { mTest.checkSdkVersion(); } catch (UnsupportedSdkException e) { BenchmarkResult r = new BenchmarkResult(e.getMessage()); - Log.v(TAG, "Test: " + r.toString()); + Log.v(TAG, "Unsupported SDK for test: " + r.toString()); return r; } - mDoingBenchmark = true; - - long result = 0; - // We run a short bit of work before starting the actual test // this is to let any power management do its job and respond. // For NNAPI systrace usage documentation, see @@ -163,79 +172,114 @@ public class Processor extends Thread { Trace.endSection(); } - Log.v(TAG, "Test: " + r.toString()); + Log.v(TAG, "Completed benchmark loop"); - mDoingBenchmark = false; return r; } @Override public void run() { - while (mRun) { - // Our loop for launching tests or benchmarks - synchronized (this) { - // We may have been asked to exit while waiting - if (!mRun) return; + mHasBeenStarted = true; + Log.d(TAG, "Processor starting"); + try { + while (mRun.get()) { + try { + benchmarkAllModels(); + } catch (IOException e) { + Log.e(TAG, "IOException during benchmark run", e); + break; + } catch (Throwable e) { + Log.e(TAG, "Error during execution", e); + throw e; + } + + mCallback.onBenchmarkFinish(mRun.get()); } + } finally { + mCompleted.countDown(); + } + } + private void benchmarkAllModels() throws IOException { + Log.i(TAG, String.format("Iterating through %d models", mTestList.length)); + // Loop over the tests we want to benchmark + for (int ct = 0; ct < mTestList.length; ct++) { + if (!mRun.get()) { + Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); + break; + } + // For reproducibility we wait a short time for any sporadic work + // created by the user touching the screen to launch the test to pass. + // Also allows for things to settle after the test changes. try { - // Loop over the tests we want to benchmark - for (int ct = 0; (ct < mTestList.length) && mRun; ct++) { + Thread.sleep(250); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + break; + } - // For reproducibility we wait a short time for any sporadic work - // created by the user touching the screen to launch the test to pass. - // Also allows for things to settle after the test changes. - try { - sleep(250); - } catch (InterruptedException e) { - } + TestModels.TestModelEntry testModel = + TestModels.modelsList().get(mTestList[ct]); - TestModels.TestModelEntry testModel = - TestModels.modelsList().get(mTestList[ct]); - int testNumber = ct + 1; - mCallback.onStatusUpdate(testNumber, mTestList.length, testModel.toString()); - - // Select the next test - mTest = changeTest(mTest, testModel); - - // If the user selected the "long pause" option, wait - if (mTogglePause) { - for (int i = 0; (i < 100) && mRun; i++) { - try { - sleep(100); - } catch (InterruptedException e) { - } - } - } + Log.i(TAG, String.format("%d/%d: '%s'", ct, mTestList.length, + testModel.mTestName)); + int testNumber = ct + 1; + mCallback.onStatusUpdate(testNumber, mTestList.length, + testModel.toString()); + + // Select the next test + try { + mTest = changeTest(mTest, testModel); + } catch (BenchmarkException e) { + Log.w(TAG, String.format("Cannot initialise test %d: '%s', skipping", ct, + testModel.mTestName), e); + } - // Run the test - float warmupTime = 0.3f; - float runTime = 1.f; - if (mToggleLong) { - warmupTime = 2.f; - runTime = 10.f; + // If the user selected the "long pause" option, wait + if (mTogglePause) { + for (int i = 0; (i < 100) && mRun.get(); i++) { + try { + Thread.sleep(100); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + break; } - mTestResults[ct] = getBenchmark(warmupTime, runTime); } - mCallback.onBenchmarkFinish(mRun); - } catch (IOException e) { - Log.e(TAG, "Exception during benchmark run", e); - break; } + + // Run the test + float warmupTime = 0.3f; + float runTime = 1.f; + if (mToggleLong) { + warmupTime = 2.f; + runTime = 10.f; + } + Log.i(TAG, "Running test for model " + testModel.mModelName + " file " + + testModel.mModelFile); + mTestResults[ct] = getBenchmark(warmupTime, runTime); } } public void exit() { - mRun = false; + exitWithTimeout(-1l); + } - synchronized (this) { - notifyAll(); - } - // exit() is called on same thread when run via dogfood BenchmarkJobService - if (this != Thread.currentThread()) { + public void exitWithTimeout(long timeoutMs) { + mRun.set(false); + + if (mHasBeenStarted) { try { - this.join(); + if (timeoutMs > 0) { + boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); + if (!hasCompleted) { + Log.w(TAG, "Exiting before execution actually completed"); + } + } else { + mCompleted.await(); + } } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + Log.w(TAG, "Interrupted while waiting for Processor to complete", e); } } diff --git a/src/com/android/nn/benchmark/core/TestModels.java b/src/com/android/nn/benchmark/core/TestModels.java index 95dfff5..7dad749 100644 --- a/src/com/android/nn/benchmark/core/TestModels.java +++ b/src/com/android/nn/benchmark/core/TestModels.java @@ -16,9 +16,9 @@ package com.android.nn.benchmark.core; -import java.util.List; import java.util.ArrayList; -import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; /** Information about available benchmarking models */ public class TestModels { @@ -88,12 +88,13 @@ public class TestModels { } } - static private List<TestModelEntry> sTestModelEntryList = new ArrayList<>(); - static private volatile boolean sTestModelEntryListFrozen = false; + static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>(); + static private final AtomicReference<List<TestModelEntry>> frozenEntries = new AtomicReference<>(null); + /** Add new benchmark model. */ static public void registerModel(TestModelEntry model) { - if (sTestModelEntryListFrozen) { + if (frozenEntries.get() != null) { throw new IllegalStateException("Can't register new models after its list is frozen"); } sTestModelEntryList.add(model); @@ -104,16 +105,8 @@ public class TestModels { * If this method was called at least once, then it's impossible to register new models. */ static public List<TestModelEntry> modelsList() { - if (!sTestModelEntryListFrozen) { - // If this method was called once, make models list unmodifiable - synchronized (TestModels.class) { - if (!sTestModelEntryListFrozen) { - sTestModelEntryList = Collections.unmodifiableList(sTestModelEntryList); - sTestModelEntryListFrozen = true; - } - } - } - return sTestModelEntryList; + frozenEntries.compareAndSet(null, sTestModelEntryList); + return frozenEntries.get(); } /** Fetch model by its name. */ |
