diff options
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp index 7e338669c7..eb5f708c81 100644 --- a/tests/InferenceModel.hpp +++ b/tests/InferenceModel.hpp @@ -24,6 +24,7 @@ #include <boost/program_options.hpp> #include <boost/filesystem.hpp> #include <boost/lexical_cast.hpp> +#include <boost/variant.hpp> #include <algorithm> #include <iterator> @@ -266,13 +267,17 @@ inline armnn::InputTensors MakeInputTensors( const InferenceModelInternal::BindingPointInfo& inputBinding = inputBindings[i]; const TContainer& inputData = inputDataContainers[i]; - if (inputData.size() != inputBinding.second.GetNumElements()) - { - throw armnn::Exception("Input tensor has incorrect size"); - } - - armnn::ConstTensor inputTensor(inputBinding.second, inputData.data()); - inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor)); + boost::apply_visitor([&](auto&& value) + { + if (value.size() != inputBinding.second.GetNumElements()) + { + throw armnn::Exception("Input tensor has incorrect size"); + } + + armnn::ConstTensor inputTensor(inputBinding.second, value.data()); + inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor)); + }, + inputData); } return inputTensors; @@ -297,13 +302,17 @@ inline armnn::OutputTensors MakeOutputTensors( const InferenceModelInternal::BindingPointInfo& outputBinding = outputBindings[i]; TContainer& outputData = outputDataContainers[i]; - if (outputData.size() != outputBinding.second.GetNumElements()) - { - throw armnn::Exception("Output tensor has incorrect size"); - } - - armnn::Tensor outputTensor(outputBinding.second, outputData.data()); - outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor)); + boost::apply_visitor([&](auto&& value) + { + if (value.size() != outputBinding.second.GetNumElements()) + { + throw armnn::Exception("Output tensor has incorrect size"); + } + + armnn::Tensor outputTensor(outputBinding.second, value.data()); + outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor)); + }, + outputData); } return outputTensors; @@ -317,7 +326,7 @@ public: using Params = InferenceModelInternal::Params; using BindingPointInfo = InferenceModelInternal::BindingPointInfo; using QuantizationParams = InferenceModelInternal::QuantizationParams; - using TContainer = std::vector<TDataType>; + using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>; struct CommandLineOptions { @@ -439,16 +448,22 @@ public: void Run(const std::vector<TContainer>& inputContainers, std::vector<TContainer>& outputContainers) { - for (unsigned int i = 0; i < outputContainers.size(); i++) + for (unsigned int i = 0; i < outputContainers.size(); ++i) { const unsigned int expectedOutputDataSize = GetOutputSize(i); - const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(outputContainers[i].size()); - if (actualOutputDataSize < expectedOutputDataSize) + + boost::apply_visitor([expectedOutputDataSize, i](auto&& value) { - unsigned int outputIndex = boost::numeric_cast<unsigned int>(i); - throw armnn::Exception(boost::str(boost::format("Not enough data for output #%1%: expected " - "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize)); - } + const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(value.size()); + if (actualOutputDataSize < expectedOutputDataSize) + { + unsigned int outputIndex = boost::numeric_cast<unsigned int>(i); + throw armnn::Exception( + boost::str(boost::format("Not enough data for output #%1%: expected " + "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize)); + } + }, + outputContainers[i]); } std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier); @@ -531,4 +546,4 @@ private: { return ::MakeOutputTensors(m_OutputBindings, outputDataContainers); } -};
\ No newline at end of file +}; |