diff options
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetworkParams.cpp')
-rw-r--r-- | tests/ExecuteNetwork/ExecuteNetworkParams.cpp | 134 |
1 files changed, 131 insertions, 3 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetworkParams.cpp b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp index f341c30738..cc75bb4323 100644 --- a/tests/ExecuteNetwork/ExecuteNetworkParams.cpp +++ b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp @@ -1,15 +1,76 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "ExecuteNetworkParams.hpp" #include "NetworkExecutionUtils/NetworkExecutionUtils.hpp" +#include <InferenceModel.hpp> #include <armnn/Logging.hpp> #include <fmt/format.h> -#include <armnnUtils/Filesystem.hpp> + +bool IsModelBinary(const std::string& modelFormat) +{ + // Parse model binary flag from the model-format string we got from the command-line + if (modelFormat.find("binary") != std::string::npos) + { + return true; + } + else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos) + { + return false; + } + else + { + throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. " + "Please include 'binary' or 'text'", + modelFormat)); + } +} + +void CheckModelFormat(const std::string& modelFormat) +{ + // Forward to implementation based on the parser type + if (modelFormat.find("armnn") != std::string::npos) + { +#if defined(ARMNN_SERIALIZER) +#else + throw armnn::InvalidArgumentException("Can't run model in armnn format without a " + "built with serialization support."); +#endif + } + else if (modelFormat.find("onnx") != std::string::npos) + { +#if defined(ARMNN_ONNX_PARSER) +#else + throw armnn::InvalidArgumentException("Can't run model in onnx format without a " + "built with Onnx parser support."); +#endif + } + else if (modelFormat.find("tflite") != std::string::npos) + { +#if defined(ARMNN_TF_LITE_PARSER) + if (!IsModelBinary(modelFormat)) + { + throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. Only 'binary' " + "format supported for tflite files", + modelFormat)); + } +#elif defined(ARMNN_TFLITE_DELEGATE) +#else + throw armnn::InvalidArgumentException("Can't run model in tflite format without a " + "built with Tensorflow Lite parser support."); +#endif + } + else + { + throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. " + "Please include 'tflite' or 'onnx'", + modelFormat)); + } +} void CheckClTuningParameter(const int& tuningLevel, const std::string& tuningPath, @@ -44,6 +105,7 @@ void CheckClTuningParameter(const int& tuningLevel, ARMNN_LOG(warning) << "To use Cl Tuning the compute device GpuAcc needs to be active."; } } + } void ExecuteNetworkParams::ValidateParams() @@ -58,6 +120,7 @@ void ExecuteNetworkParams::ValidateParams() << invalidBackends; } } + CheckClTuningParameter(m_TuningLevel, m_TuningPath, m_ComputeDevices); if (m_EnableBf16TurboMode && m_EnableFp16TurboMode) @@ -66,6 +129,10 @@ void ExecuteNetworkParams::ValidateParams() "enabled at the same time."); } + m_IsModelBinary = IsModelBinary(m_ModelFormat); + + CheckModelFormat(m_ModelFormat); + // Check input tensor shapes if ((m_InputTensorShapes.size() != 0) && (m_InputTensorShapes.size() != m_InputNames.size())) @@ -90,6 +157,68 @@ void ExecuteNetworkParams::ValidateParams() m_InputNames.size(), m_InputTensorDataFilePaths.size())); } + else if (m_InputTensorDataFilePaths.size() % m_InputNames.size() != 0) + { + throw armnn::InvalidArgumentException( + fmt::format("According to the number of input names the user provided the network has {} " + "inputs. The user specified {} input-tensor-data file paths which is not " + "divisible by the number of inputs.", + m_InputNames.size(), + m_InputTensorDataFilePaths.size())); + } + } + + if (m_InputTypes.size() == 0) + { + //Defaults the value of all inputs to "float" + m_InputTypes.assign(m_InputNames.size(), "float"); + } + else if ((m_InputTypes.size() != 0) && + (m_InputTypes.size() != m_InputNames.size())) + { + throw armnn::InvalidArgumentException("input-name and input-type must have the same amount of elements."); + } + + // Make sure that the number of input files given is divisible by the number of inputs of the model + if (!(m_InputTensorDataFilePaths.size() % m_InputNames.size() == 0)) + { + throw armnn::InvalidArgumentException( + fmt::format("The number of input-tensor-data files ({0}) is not divisible by the " + "number of inputs ({1} according to the number of input names).", + m_InputTensorDataFilePaths.size(), + m_InputNames.size())); + } + + if (m_OutputTypes.size() == 0) + { + //Defaults the value of all outputs to "float" + m_OutputTypes.assign(m_OutputNames.size(), "float"); + } + else if ((m_OutputTypes.size() != 0) && + (m_OutputTypes.size() != m_OutputNames.size())) + { + throw armnn::InvalidArgumentException("output-name and output-type must have the same amount of elements."); + } + + // Make sure that the number of output files given is equal to the number of outputs of the model + // or equal to the number of outputs of the model multiplied with the number of iterations + if (!m_OutputTensorFiles.empty()) + { + if ((m_OutputTensorFiles.size() != m_OutputNames.size()) && + (m_OutputTensorFiles.size() != m_OutputNames.size() * m_Iterations)) + { + std::stringstream errmsg; + auto numOutputs = m_OutputNames.size(); + throw armnn::InvalidArgumentException( + fmt::format("The user provided {0} output-tensor files. The only allowed number of output-tensor " + "files is the number of outputs of the network ({1} according to the number of " + "output names) or the number of outputs multiplied with the number of times the " + "network should be executed (NumOutputs * NumIterations = {1} * {2} = {3}).", + m_OutputTensorFiles.size(), + numOutputs, + m_Iterations, + numOutputs*m_Iterations)); + } } // Check that threshold time is not less than zero @@ -181,5 +310,4 @@ armnnDelegate::DelegateOptions ExecuteNetworkParams::ToDelegateOptions() const return delegateOptions; } - #endif |