aboutsummaryrefslogtreecommitdiff
path: root/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ExecuteNetwork/ExecuteNetworkParams.cpp')
-rw-r--r--tests/ExecuteNetwork/ExecuteNetworkParams.cpp212
1 files changed, 212 insertions, 0 deletions
diff --git a/tests/ExecuteNetwork/ExecuteNetworkParams.cpp b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
new file mode 100644
index 0000000000..c298bd614a
--- /dev/null
+++ b/tests/ExecuteNetwork/ExecuteNetworkParams.cpp
@@ -0,0 +1,212 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ExecuteNetworkParams.hpp"
+
+#include "NetworkExecutionUtils/NetworkExecutionUtils.hpp"
+#include <InferenceModel.hpp>
+#include <armnn/Logging.hpp>
+
+#include <fmt/format.h>
+
+bool IsModelBinary(const std::string& modelFormat)
+{
+ // Parse model binary flag from the model-format string we got from the command-line
+ if (modelFormat.find("binary") != std::string::npos)
+ {
+ return true;
+ }
+ else if (modelFormat.find("txt") != std::string::npos || modelFormat.find("text") != std::string::npos)
+ {
+ return false;
+ }
+ else
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. "
+ "Please include 'binary' or 'text'",
+ modelFormat));
+ }
+}
+
+void CheckModelFormat(const std::string& modelFormat)
+{
+ // Forward to implementation based on the parser type
+ if (modelFormat.find("armnn") != std::string::npos)
+ {
+#if defined(ARMNN_SERIALIZER)
+#else
+ throw armnn::InvalidArgumentException("Can't run model in armnn format without a "
+ "built with serialization support.");
+#endif
+ }
+ else if (modelFormat.find("caffe") != std::string::npos)
+ {
+#if defined(ARMNN_CAFFE_PARSER)
+#else
+ throw armnn::InvalidArgumentException("Can't run model in caffe format without a "
+ "built with Caffe parser support.");
+#endif
+ }
+ else if (modelFormat.find("onnx") != std::string::npos)
+ {
+#if defined(ARMNN_ONNX_PARSER)
+#else
+ throw armnn::InvalidArgumentException("Can't run model in onnx format without a "
+ "built with Onnx parser support.");
+#endif
+ }
+ else if (modelFormat.find("tensorflow") != std::string::npos)
+ {
+#if defined(ARMNN_TF_PARSER)
+#else
+ throw armnn::InvalidArgumentException("Can't run model in onnx format without a "
+ "built with Tensorflow parser support.");
+#endif
+ }
+ else if(modelFormat.find("tflite") != std::string::npos)
+ {
+#if defined(ARMNN_TF_LITE_PARSER)
+ if (!IsModelBinary(modelFormat))
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. Only 'binary' format "
+ "supported for tflite files",
+ modelFormat));
+ }
+#else
+ throw armnn::InvalidArgumentException("Can't run model in tflite format without a "
+ "built with Tensorflow Lite parser support.");
+#endif
+ }
+ else
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Unknown model format: '{}'. "
+ "Please include 'caffe', 'tensorflow', 'tflite' or 'onnx'",
+ modelFormat));
+ }
+}
+
+void CheckClTuningParameter(const int& tuningLevel,
+ const std::string& tuningPath,
+ const std::vector<armnn::BackendId> computeDevices)
+{
+ if (!tuningPath.empty())
+ {
+ if(tuningLevel == 0)
+ {
+ ARMNN_LOG(info) << "Using cl tuning file: " << tuningPath << "\n";
+ if(!ValidatePath(tuningPath, true))
+ {
+ throw armnn::InvalidArgumentException("The tuning path is not valid");
+ }
+ }
+ else if ((1 <= tuningLevel) && (tuningLevel <= 3))
+ {
+ ARMNN_LOG(info) << "Starting execution to generate a cl tuning file: " << tuningPath << "\n"
+ << "Tuning level in use: " << tuningLevel << "\n";
+ }
+ else if ((0 < tuningLevel) || (tuningLevel > 3))
+ {
+ throw armnn::InvalidArgumentException(fmt::format("The tuning level {} is not valid.", tuningLevel));
+ }
+
+ // Ensure that a GpuAcc is enabled. Otherwise no tuning data are used or genereted
+ // Only warn if it's not enabled
+ auto it = std::find(computeDevices.begin(), computeDevices.end(), "GpuAcc");
+ if (it == computeDevices.end())
+ {
+ ARMNN_LOG(warning) << "To use Cl Tuning the compute device GpuAcc needs to be active.";
+ }
+ }
+
+
+}
+
+void ExecuteNetworkParams::ValidateParams()
+{
+ // Check compute devices
+ std::string invalidBackends;
+ if (!CheckRequestedBackendsAreValid(m_ComputeDevices, armnn::Optional<std::string&>(invalidBackends)))
+ {
+ throw armnn::InvalidArgumentException(fmt::format("Some of the requested compute devices are invalid. "
+ "\nInvalid devices: {} \nAvailable devices are: {}",
+ invalidBackends,
+ armnn::BackendRegistryInstance().GetBackendIdsAsString()));
+ }
+
+ CheckClTuningParameter(m_TuningLevel, m_TuningPath, m_ComputeDevices);
+
+ // Check turbo modes
+ if (m_EnableBf16TurboMode && m_EnableFp16TurboMode)
+ {
+ throw armnn::InvalidArgumentException("BFloat16 and Float16 turbo mode cannot be enabled at the same time.");
+ }
+
+ m_IsModelBinary = IsModelBinary(m_ModelFormat);
+
+ CheckModelFormat(m_ModelFormat);
+
+ // Check input tensor shapes
+ if ((m_InputTensorShapes.size() != 0) &&
+ (m_InputTensorShapes.size() != m_InputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("input-name and input-tensor-shape must "
+ "have the same amount of elements.");
+ }
+
+ if (m_InputTensorDataFilePaths.size() != 0)
+ {
+ if (!ValidatePaths(m_InputTensorDataFilePaths, true))
+ {
+ throw armnn::InvalidArgumentException("One or more input data file paths are not valid.");
+ }
+
+ if (m_InputTensorDataFilePaths.size() != m_InputNames.size())
+ {
+ throw armnn::InvalidArgumentException("input-name and input-tensor-data must have "
+ "the same amount of elements.");
+ }
+ }
+
+ if ((m_OutputTensorFiles.size() != 0) &&
+ (m_OutputTensorFiles.size() != m_OutputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("output-name and write-outputs-to-file must have the "
+ "same amount of elements.");
+ }
+
+ if (m_InputTypes.size() == 0)
+ {
+ //Defaults the value of all inputs to "float"
+ m_InputTypes.assign(m_InputNames.size(), "float");
+ }
+ else if ((m_InputTypes.size() != 0) &&
+ (m_InputTypes.size() != m_InputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("input-name and input-type must have the same amount of elements.");
+ }
+
+ if (m_OutputTypes.size() == 0)
+ {
+ //Defaults the value of all outputs to "float"
+ m_OutputTypes.assign(m_OutputNames.size(), "float");
+ }
+ else if ((m_OutputTypes.size() != 0) &&
+ (m_OutputTypes.size() != m_OutputNames.size()))
+ {
+ throw armnn::InvalidArgumentException("output-name and output-type must have the same amount of elements.");
+ }
+
+ // Check that threshold time is not less than zero
+ if (m_ThresholdTime < 0)
+ {
+ throw armnn::InvalidArgumentException("Threshold time supplied as a command line argument is less than zero.");
+ }
+
+ // Warn if ExecuteNetwork will generate dummy input data
+ if (m_GenerateTensorData)
+ {
+ ARMNN_LOG(warning) << "No input files provided, input tensors will be filled with 0s.";
+ }
+} \ No newline at end of file