diff options
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 50 |
1 files changed, 47 insertions, 3 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 8645c9041a..8ef17d4df5 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -14,6 +14,8 @@ #include <armnnOnnxParser/IOnnxParser.hpp> #endif +#include <BackendRegistry.hpp> + #include <boost/exception/exception.hpp> #include <boost/exception/diagnostic_information.hpp> #include <boost/log/trivial.hpp> @@ -22,11 +24,45 @@ #include <boost/filesystem.hpp> #include <boost/lexical_cast.hpp> +#include <fstream> #include <map> #include <string> -#include <fstream> #include <type_traits> +namespace +{ + +inline bool CheckRequestedBackendsAreValid(const std::vector<armnn::BackendId>& backendIds, + armnn::Optional<std::string&> invalidBackendIds = armnn::EmptyOptional()) +{ + if (backendIds.empty()) + { + return false; + } + + armnn::BackendIdSet validBackendIds = armnn::BackendRegistryInstance().GetBackendIds(); + + bool allValid = true; + for (const auto& backendId : backendIds) + { + if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end()) + { + allValid = false; + if (invalidBackendIds) + { + if (!invalidBackendIds.value().empty()) + { + invalidBackendIds.value() += ", "; + } + invalidBackendIds.value() += backendId; + } + } + } + return allValid; +} + +} // anonymous namespace + namespace InferenceModelInternal { // This needs to go when the armnnCaffeParser, armnnTfParser and armnnTfLiteParser @@ -217,12 +253,14 @@ public: std::vector<armnn::BackendId> defaultBackends = {armnn::Compute::CpuAcc, armnn::Compute::CpuRef}; + const std::string backendsMessage = "Which device to run layers on by default. Possible choices: " + + armnn::BackendRegistryInstance().GetBackendIdsAsString(); + desc.add_options() ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(), "Path to directory containing model files (.caffemodel/.prototxt/.tflite)") ("compute,c", po::value<std::vector<armnn::BackendId>>(&options.m_ComputeDevice)->default_value - (defaultBackends), - "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc") + (defaultBackends), backendsMessage.c_str()) ("visualize-optimized-model,v", po::value<bool>(&options.m_VisualizePostOptimizationModel)->default_value(false), "Produce a dot file useful for visualizing the graph post optimization." @@ -246,6 +284,12 @@ public: m_Runtime = std::move(armnn::IRuntime::Create(options)); } + std::string invalidBackends; + if (!CheckRequestedBackendsAreValid(params.m_ComputeDevice, armnn::Optional<std::string&>(invalidBackends))) + { + throw armnn::Exception("Some backend IDs are invalid: " + invalidBackends); + } + armnn::INetworkPtr network = CreateNetworkImpl<IParser>::Create(params, m_InputBindingInfo, m_OutputBindingInfo); |