Diffstat (limited to 'tests/InferenceModel.hpp')
 tests/InferenceModel.hpp | 50 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 47 insertions(+), 3 deletions(-)
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 8645c9041a..8ef17d4df5 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -14,6 +14,8 @@
 #include <armnnOnnxParser/IOnnxParser.hpp>
 #endif
 
+#include <BackendRegistry.hpp>
+
 #include <boost/exception/exception.hpp>
 #include <boost/exception/diagnostic_information.hpp>
 #include <boost/log/trivial.hpp>
@@ -22,11 +24,45 @@
 #include <boost/filesystem.hpp>
 #include <boost/lexical_cast.hpp>
 
+#include <fstream>
 #include <map>
 #include <string>
-#include <fstream>
 #include <type_traits>
 
+namespace
+{
+
+inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds,
+                                           armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional())
+{
+    if (backendIds.empty())
+    {
+        return false;
+    }
+
+    armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds();
+
+    bool allValid = true;
+    for (const auto& backendId : backendIds)
+    {
+        if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
+        {
+            allValid = false;
+            if (invalidBackendIds)
+            {
+                if (!invalidBackendIds.value().empty())
+                {
+                    invalidBackendIds.value() += ", ";
+                }
+                invalidBackendIds.value() += backendId;
+            }
+        }
+    }
+    return allValid;
+}
+
+} // anonymous namespace
+
 namespace InferenceModelInternal
 {
 // This needs to go when the armnnCaffeParser, armnnTfParser and armnnTfLiteParser
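For context (not part of the patch), a minimal sketch of how the new helper behaves, assuming armnn::BackendId is implicitly constructible from a string literal as elsewhere in the Arm NN API; "NotARealBackend" is a made-up ID used only for illustration:

    std::vector<armnn::BackendId> requested = { armnn::Compute::CpuRef, "NotARealBackend" };

    std::string invalid;
    if (!CheckRequestedBackendsAreValid(requested, armnn::Optional<std::string&>(invalid)))
    {
        // "invalid" now holds the unknown IDs as a comma-separated list,
        // here "NotARealBackend".
        std::cerr << "Invalid backend ID(s): " << invalid << std::endl;
    }

Passing the default armnn::EmptyOptional() instead skips the collection and just returns false.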
@@ -217,12 +253,14 @@ public:
     std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef};
 
+    const std::string backendsMessage = "Which device to run layers on by default. Possible choices: "
+                                      + armnn::BackendRegistryInstance().GetBackendIdsAsString();
+
     desc.add_options()
         ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(),
             "Path to directory containing model files (.caffemodel/.prototxt/.tflite)")
         ("compute,c", po::value<std::vector<armnn::BackendId>>(&options.m_ComputeDevice)->default_value
-            (defaultBackends),
-            "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc")
+            (defaultBackends), backendsMessage.c_str())
         ("visualize-optimized-model,v",
             po::value<bool>(&options.m_VisualizePostOptimizationModel)->default_value(false),
             "Produce a dot file useful for visualizing the graph post optimization."
@@ -246,6 +284,12 @@ public:
         m_Runtime = std::move(armnn::IRuntime::Create(options));
     }
 
+    std::string invalidBackends;
+    if (!CheckRequestedBackendsAreValid(params.m_ComputeDevice, armnn::Optional<std::string&>(invalidBackends)))
+    {
+        throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends);
+    }
+
     armnn::INetworkPtr network = CreateNetworkImpl<IParser>::Create(params, m_InputBindingInfo,
                                                                     m_OutputBindingInfo);
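With this check in place, a typo such as -c CpuAk fails fast at model construction instead of surfacing as a harder-to-read error later during optimization. A hypothetical caller (the template arguments and the "params" variable are illustrative, not taken from the patch):

    try
    {
        // The constructor now validates params.m_ComputeDevice against the
        // backend registry before parsing and optimizing the network.
        InferenceModel<IParser, TDataType> model(params);
    }
    catch (const armnn::Exception& e)
    {
        // e.g. "Some backend IDs are invalid: CpuAk"
        std::cerr << e.what() << std::endl;
    }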