diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-03-24 13:54:05 +0000 |
---|---|---|
committer | Jim Flynn <jim.flynn@arm.com> | 2020-03-24 16:45:36 +0000 |
commit | d8cc8116f2deea11ad7aff9218a2e103062a7daf (patch) | |
tree | 63ca615ca9d5f8a1655f560518fc65b662d65e51 /tests/ExecuteNetwork | |
parent | b6a402f46231688f7684dcb8c8e4ef5f4579b011 (diff) | |
download | armnn-d8cc8116f2deea11ad7aff9218a2e103062a7daf.tar.gz |
IVGCVSW-4521 Add bf16-turbo-mode option to ExecuteNetwork
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I57ec47adf98680254fa481fb91d5a98dea8f032e
Diffstat (limited to 'tests/ExecuteNetwork')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetwork.cpp | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetwork.cpp b/tests/ExecuteNetwork/ExecuteNetwork.cpp index e9811d523a..a59f58074b 100644 --- a/tests/ExecuteNetwork/ExecuteNetwork.cpp +++ b/tests/ExecuteNetwork/ExecuteNetwork.cpp @@ -97,6 +97,8 @@ int main(int argc, const char* argv[]) "Enables built in profiler. If unset, defaults to off.") ("visualize-optimized-model,v", po::bool_switch()->default_value(false), "Enables built optimized model visualizer. If unset, defaults to off.") + ("bf16-turbo-mode", po::bool_switch()->default_value(false), "If this option is enabled, FP32 layers, " + "weights and biases will be converted to BFloat16 where the backend supports it") ("fp16-turbo-mode,h", po::bool_switch()->default_value(false), "If this option is enabled, FP32 layers, " "weights and biases will be converted to FP16 where the backend supports it") ("threshold-time,r", po::value<double>(&thresholdTime)->default_value(0.0), @@ -158,6 +160,7 @@ int main(int argc, const char* argv[]) bool concurrent = vm["concurrent"].as<bool>(); bool enableProfiling = vm["event-based-profiling"].as<bool>(); bool enableLayerDetails = vm["visualize-optimized-model"].as<bool>(); + bool enableBf16TurboMode = vm["bf16-turbo-mode"].as<bool>(); bool enableFp16TurboMode = vm["fp16-turbo-mode"].as<bool>(); bool quantizeInput = vm["quantize-input"].as<bool>(); bool dequantizeOutput = vm["dequantize-output"].as<bool>(); @@ -166,6 +169,12 @@ int main(int argc, const char* argv[]) bool fileOnlyExternalProfiling = vm["file-only-external-profiling"].as<bool>(); bool parseUnsupported = vm["parse-unsupported"].as<bool>(); + if (enableBf16TurboMode && enableFp16TurboMode) + { + ARMNN_LOG(fatal) << "BFloat16 and Float16 turbo mode cannot be enabled at the same time."; + return EXIT_FAILURE; + } + // Check whether we have to load test cases from a file. if (CheckOption(vm, "test-cases")) @@ -213,8 +222,8 @@ int main(int argc, const char* argv[]) { testCase.values.insert(testCase.values.begin(), executableName); results.push_back(std::async(std::launch::async, RunCsvTest, std::cref(testCase), std::cref(runtime), - enableProfiling, enableFp16TurboMode, thresholdTime, printIntermediate, - enableLayerDetails, parseUnsupported)); + enableProfiling, enableFp16TurboMode, enableBf16TurboMode, thresholdTime, + printIntermediate, enableLayerDetails, parseUnsupported)); } // Check results @@ -233,7 +242,7 @@ int main(int argc, const char* argv[]) { testCase.values.insert(testCase.values.begin(), executableName); if (RunCsvTest(testCase, runtime, enableProfiling, - enableFp16TurboMode, thresholdTime, printIntermediate, + enableFp16TurboMode, enableBf16TurboMode, thresholdTime, printIntermediate, enableLayerDetails, parseUnsupported) != EXIT_SUCCESS) { return EXIT_FAILURE; @@ -280,7 +289,7 @@ int main(int argc, const char* argv[]) return RunTest(modelFormat, inputTensorShapes, computeDevices, dynamicBackendsPath, modelPath, inputNames, inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames, - outputTensorFiles, dequantizeOutput, enableProfiling, enableFp16TurboMode, thresholdTime, - printIntermediate, subgraphId, enableLayerDetails, parseUnsupported, runtime); + outputTensorFiles, dequantizeOutput, enableProfiling, enableFp16TurboMode, enableBf16TurboMode, + thresholdTime, printIntermediate, subgraphId, enableLayerDetails, parseUnsupported, runtime); } } |