aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--tests/InferenceModel.hpp61
1 files changed, 38 insertions, 23 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 7e338669c7..eb5f708c81 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -24,6 +24,7 @@
#include <boost/program_options.hpp>
#include <boost/filesystem.hpp>
#include <boost/lexical_cast.hpp>
+#include <boost/variant.hpp>
#include <algorithm>
#include <iterator>
@@ -266,13 +267,17 @@ inline armnn::InputTensors MakeInputTensors(
const InferenceModelInternal::BindingPointInfo& inputBinding = inputBindings[i];
const TContainer& inputData = inputDataContainers[i];
- if (inputData.size() != inputBinding.second.GetNumElements())
- {
- throw armnn::Exception("Input tensor has incorrect size");
- }
-
- armnn::ConstTensor inputTensor(inputBinding.second, inputData.data());
- inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
+ boost::apply_visitor([&](auto&& value)
+ {
+ if (value.size() != inputBinding.second.GetNumElements())
+ {
+ throw armnn::Exception("Input tensor has incorrect size");
+ }
+
+ armnn::ConstTensor inputTensor(inputBinding.second, value.data());
+ inputTensors.push_back(std::make_pair(inputBinding.first, inputTensor));
+ },
+ inputData);
}
return inputTensors;
@@ -297,13 +302,17 @@ inline armnn::OutputTensors MakeOutputTensors(
const InferenceModelInternal::BindingPointInfo& outputBinding = outputBindings[i];
TContainer& outputData = outputDataContainers[i];
- if (outputData.size() != outputBinding.second.GetNumElements())
- {
- throw armnn::Exception("Output tensor has incorrect size");
- }
-
- armnn::Tensor outputTensor(outputBinding.second, outputData.data());
- outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
+ boost::apply_visitor([&](auto&& value)
+ {
+ if (value.size() != outputBinding.second.GetNumElements())
+ {
+ throw armnn::Exception("Output tensor has incorrect size");
+ }
+
+ armnn::Tensor outputTensor(outputBinding.second, value.data());
+ outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
+ },
+ outputData);
}
return outputTensors;
@@ -317,7 +326,7 @@ public:
using Params = InferenceModelInternal::Params;
using BindingPointInfo = InferenceModelInternal::BindingPointInfo;
using QuantizationParams = InferenceModelInternal::QuantizationParams;
- using TContainer = std::vector<TDataType>;
+ using TContainer = boost::variant<std::vector<float>, std::vector<int>, std::vector<unsigned char>>;
struct CommandLineOptions
{
@@ -439,16 +448,22 @@ public:
void Run(const std::vector<TContainer>& inputContainers, std::vector<TContainer>& outputContainers)
{
- for (unsigned int i = 0; i < outputContainers.size(); i++)
+ for (unsigned int i = 0; i < outputContainers.size(); ++i)
{
const unsigned int expectedOutputDataSize = GetOutputSize(i);
- const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(outputContainers[i].size());
- if (actualOutputDataSize < expectedOutputDataSize)
+
+ boost::apply_visitor([expectedOutputDataSize, i](auto&& value)
{
- unsigned int outputIndex = boost::numeric_cast<unsigned int>(i);
- throw armnn::Exception(boost::str(boost::format("Not enough data for output #%1%: expected "
- "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize));
- }
+ const unsigned int actualOutputDataSize = boost::numeric_cast<unsigned int>(value.size());
+ if (actualOutputDataSize < expectedOutputDataSize)
+ {
+ unsigned int outputIndex = boost::numeric_cast<unsigned int>(i);
+ throw armnn::Exception(
+ boost::str(boost::format("Not enough data for output #%1%: expected "
+ "%2% elements, got %3%") % outputIndex % expectedOutputDataSize % actualOutputDataSize));
+ }
+ },
+ outputContainers[i]);
}
std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkIdentifier);
@@ -531,4 +546,4 @@ private:
{
return ::MakeOutputTensors(m_OutputBindings, outputDataContainers);
}
-}; \ No newline at end of file
+};