aboutsummaryrefslogtreecommitdiff
path: root/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
diff options
context:
space:
mode:
authorNikhil Raj Arm <nikhil.raj@arm.com>2022-07-05 09:29:18 +0000
committerNikhil Raj <nikhil.raj@arm.com>2022-07-08 15:21:03 +0100
commitf4ccb1f6339a1e9ed573f188e7f14353167b5749 (patch)
treebb53a449cd42ed919022bd52b9e369a28d5a14d4 /tests/ExecuteNetwork/ExecuteNetworkParams.cpp
parentfd33a698ee3c588aa4064b70b7781ab25ff76f66 (diff)
downloadarmnn-f4ccb1f6339a1e9ed573f188e7f14353167b5749.tar.gz
Revert "IVGCVSW-6650 Refactor ExecuteNetwork"
This reverts commit 615e06f54a4c4139e81e289991ba4084aa2f69d3. Reason for revert: <Breaking nightlies and tests> Change-Id: I06a4a0119463188a653bb749033f78514645bd0c
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetworkParams.cpp')
-rw-r--r--tests/ExecuteNetwork/ExecuteNetworkParams.cpp134
1 file changed, 131 insertions, 3 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetworkParams.cpp b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
index f341c30738..cc75bb4323 100644
--- a/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
+++ b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
@@ -1,15 +1,76 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ExecuteNetworkParams.hpp"
#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
+#include <InferenceModel.hpp>
#include <armnn/Logging.hpp>
#include <fmt/format.h>
-#include <armnnUtils/Filesystem.hpp>
+
+// Decide from the --model-format command-line string whether the model file
+// is binary or human-readable text.
+// Returns true if the string contains "binary", false if it contains "txt"
+// or "text"; any other value is rejected with InvalidArgumentException.
+bool IsModelBinary(const std::string& modelFormat)
+{
+ // Parse model binary flag from the model-format string we got from the command-line
+ if (modelFormat.find("binary") != std::string::npos)
+ {
+ return true;
+ }
+ else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
+ {
+ return false;
+ }
+ else
+ {
+ // Neither token found - the format string is unusable, so fail loudly
+ // rather than silently assuming a default.
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. "
+ "Please include 'binary' or 'text'",
+ modelFormat));
+ }
+}
+
+// Validate that the requested --model-format is usable by the parsers this
+// binary was compiled with; throws InvalidArgumentException otherwise.
+// NOTE(review): matching is substring-based on the format string, checked in
+// the order "armnn", "onnx", "tflite" - the first match wins.
+void CheckModelFormat(const std::string& modelFormat)
+{
+ // Forward to implementation based on the parser type
+ if (modelFormat.find("armnn") != std::string::npos)
+ {
+#if defined(ARMNN_SERIALIZER)
+ // ArmNN serializer compiled in - format is acceptable, nothing to check.
+#else
+ throw armnn::InvalidArgumentException("Can't run model in armnn format without a "
+ "built with serialization support.");
+#endif
+ }
+ else if (modelFormat.find("onnx") != std::string::npos)
+ {
+#if defined(ARMNN_ONNX_PARSER)
+ // Onnx parser compiled in - format is acceptable, nothing to check.
+#else
+ throw armnn::InvalidArgumentException("Can't run model in onnx format without a "
+ "built with Onnx parser support.");
+#endif
+ }
+ else if (modelFormat.find("tflite") != std::string::npos)
+ {
+#if defined(ARMNN_TF_LITE_PARSER)
+ // The TfLite parser path only accepts binary (flatbuffer) files, so
+ // reject a request for a text tflite model.
+ if (!IsModelBinary(modelFormat))
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. Only 'binary' "
+ "format supported for tflite files",
+ modelFormat));
+ }
+#elif defined(ARMNN_TFLITE_DELEGATE)
+ // Delegate build accepts the tflite format with no further checks here.
+#else
+ throw armnn::InvalidArgumentException("Can't run model in tflite format without a "
+ "built with Tensorflow Lite parser support.");
+#endif
+ }
+ else
+ {
+ // None of the known parser names matched the format string.
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. "
+ "Please include 'tflite' or 'onnx'",
+ modelFormat));
+ }
+}
void CheckClTuningParameter(const int& tuningLevel,
const std::string& tuningPath,
@@ -44,6 +105,7 @@ void CheckClTuningParameter(const int& tuningLevel,
ARMNN_LOG(warning) << "To use Cl Tuning the compute device GpuAcc needs to be active.";
}
}
+
}
void ExecuteNetworkParams::ValidateParams()
@@ -58,6 +120,7 @@ void ExecuteNetworkParams::ValidateParams()
<< invalidBackends;
}
}
+
CheckClTuningParameter(m_TuningLevel, m_TuningPath, m_ComputeDevices);
if (m_EnableBf16TurboMode && m_EnableFp16TurboMode)
@@ -66,6 +129,10 @@ void ExecuteNetworkParams::ValidateParams()
"enabled at the same time.");
}
+ m_IsModelBinary = IsModelBinary(m_ModelFormat);
+
+ CheckModelFormat(m_ModelFormat);
+
// Check input tensor shapes
if ((m_InputTensorShapes.size() != 0) &&
(m_InputTensorShapes.size() != m_InputNames.size()))
@@ -90,6 +157,68 @@ void ExecuteNetworkParams::ValidateParams()
m_InputNames.size(),
m_InputTensorDataFilePaths.size()));
}
+ else if (m_InputTensorDataFilePaths.size() % m_InputNames.size() != 0)
+ {
+ throw armnn::InvalidArgumentException(
+ fmt::format("According to the number of input names the user provided the network has {} "
+ "inputs. The user specified {} input-tensor-data file paths which is not "
+ "divisible by the number of inputs.",
+ m_InputNames.size(),
+ m_InputTensorDataFilePaths.size()));
+ }
+ }
+
+ if (m_InputTypes.size() == 0)
+ {
+ //Defaults the value of all inputs to "float"
+ m_InputTypes.assign(m_InputNames.size(), "float");
+ }
+ else if ((m_InputTypes.size() != 0) &&
+ (m_InputTypes.size() != m_InputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("input-name and input-type must have the same amount of elements.");
+ }
+
+ // Make sure that the number of input files given is divisible by the number of inputs of the model
+ if (!(m_InputTensorDataFilePaths.size() % m_InputNames.size() == 0))
+ {
+ throw armnn::InvalidArgumentException(
+ fmt::format("The number of input-tensor-data files ({0}) is not divisible by the "
+ "number of inputs ({1} according to the number of input names).",
+ m_InputTensorDataFilePaths.size(),
+ m_InputNames.size()));
+ }
+
+ if (m_OutputTypes.size() == 0)
+ {
+ //Defaults the value of all outputs to "float"
+ m_OutputTypes.assign(m_OutputNames.size(), "float");
+ }
+ else if ((m_OutputTypes.size() != 0) &&
+ (m_OutputTypes.size() != m_OutputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("output-name and output-type must have the same amount of elements.");
+ }
+
+ // Make sure that the number of output files given is equal to the number of outputs of the model
+ // or equal to the number of outputs of the model multiplied with the number of iterations
+ if (!m_OutputTensorFiles.empty())
+ {
+ if ((m_OutputTensorFiles.size() != m_OutputNames.size()) &&
+ (m_OutputTensorFiles.size() != m_OutputNames.size() * m_Iterations))
+ {
+ std::stringstream errmsg;
+ auto numOutputs = m_OutputNames.size();
+ throw armnn::InvalidArgumentException(
+ fmt::format("The user provided {0} output-tensor files. The only allowed number of output-tensor "
+ "files is the number of outputs of the network ({1} according to the number of "
+ "output names) or the number of outputs multiplied with the number of times the "
+ "network should be executed (NumOutputs * NumIterations = {1} * {2} = {3}).",
+ m_OutputTensorFiles.size(),
+ numOutputs,
+ m_Iterations,
+ numOutputs*m_Iterations));
+ }
}
// Check that threshold time is not less than zero
@@ -181,5 +310,4 @@ armnnDelegate::DelegateOptions ExecuteNetworkParams::ToDelegateOptions() const
return delegateOptions;
}
-
#endif