aboutsummaryrefslogtreecommitdiff
path: root/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp')
-rw-r--r-- tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp | 17
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
index 31f37916b8..69941d5678 100644
--- a/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
+++ b/tests/NetworkExecutionUtils/NetworkExecutionUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/ArmNN.hpp>
@@ -375,6 +375,7 @@ struct ExecuteNetworkParams
bool m_EnableLayerDetails = false;
bool m_GenerateTensorData;
bool m_ParseUnsupported = false;
+ bool m_InferOutputShape = false;
};
template<typename TParser, typename TDataType>
@@ -397,6 +398,7 @@ int MainImpl(const ExecuteNetworkParams& params,
inferenceModelParams.m_PrintIntermediateLayers = params.m_PrintIntermediate;
inferenceModelParams.m_VisualizePostOptimizationModel = params.m_EnableLayerDetails;
inferenceModelParams.m_ParseUnsupported = params.m_ParseUnsupported;
+ inferenceModelParams.m_InferOutputShape = params.m_InferOutputShape;
for(const std::string& inputName: params.m_InputNames)
{
@@ -550,6 +552,7 @@ int RunTest(const std::string& format,
const size_t subgraphId,
bool enableLayerDetails = false,
bool parseUnsupported = false,
+ bool inferOutputShape = false,
const size_t iterations = 1,
const std::shared_ptr<armnn::IRuntime>& runtime = nullptr)
{
@@ -678,6 +681,7 @@ int RunTest(const std::string& format,
params.m_EnableLayerDetails = enableLayerDetails;
params.m_GenerateTensorData = inputTensorDataFilePathsVector.empty();
params.m_ParseUnsupported = parseUnsupported;
+ params.m_InferOutputShape = inferOutputShape;
// Warn if ExecuteNetwork will generate dummy input data
if (params.m_GenerateTensorData)
@@ -749,7 +753,7 @@ int RunTest(const std::string& format,
int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IRuntime>& runtime,
const bool enableProfiling, const bool enableFp16TurboMode, const bool enableBf16TurboMode,
const double& thresholdTime, const bool printIntermediate, bool enableLayerDetails = false,
- bool parseUnuspported = false)
+ bool parseUnuspported = false, bool inferOutputShape = false)
{
IgnoreUnused(runtime);
std::string modelFormat;
@@ -869,7 +873,8 @@ int RunCsvTest(const armnnUtils::CsvRow &csvRow, const std::shared_ptr<armnn::IR
return RunTest(modelFormat, inputTensorShapes, computeDevices, dynamicBackendsPath, modelPath, inputNames,
inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames, outputTensorFiles,
dequantizeOutput, enableProfiling, enableFp16TurboMode, enableBf16TurboMode,
- thresholdTime, printIntermediate, subgraphId, enableLayerDetails, parseUnuspported);
+ thresholdTime, printIntermediate, subgraphId, enableLayerDetails, parseUnuspported,
+ inferOutputShape);
}
#if defined(ARMCOMPUTECL_ENABLED)
@@ -895,7 +900,8 @@ int RunCLTuning(const std::string& tuningPath,
bool printIntermediate,
const size_t subgraphId,
bool enableLayerDetails = false,
- bool parseUnsupported = false)
+ bool parseUnsupported = false,
+ bool inferOutputShape = false)
{
armnn::IRuntime::CreationOptions options;
options.m_BackendOptions.emplace_back(
@@ -917,7 +923,8 @@ int RunCLTuning(const std::string& tuningPath,
int state = RunTest(modelFormat, inputTensorShapes, computeDevices, dynamicBackendsPath, modelPath, inputNames,
inputTensorDataFilePaths, inputTypes, quantizeInput, outputTypes, outputNames,
outputTensorFiles, dequantizeOutput, enableProfiling, enableFp16TurboMode, enableBf16TurboMode,
- thresholdTime, printIntermediate, subgraphId, enableLayerDetails, parseUnsupported, 1, runtime);
+ thresholdTime, printIntermediate, subgraphId, enableLayerDetails, parseUnsupported,
+ inferOutputShape, 1, runtime);
ARMNN_LOG(info) << "Tuning time: " << std::setprecision(2)
<< std::fixed << armnn::GetTimeDuration(start_time).count() << " ms\n";