aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceModel.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r--  tests/InferenceModel.hpp  26
1 file changed, 26 insertions, 0 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 02511965d9..cf3aae137e 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -312,6 +312,32 @@ public:
armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+ std::map<std::string, armnn::TensorShape> inputShapes;
+ if (!params.m_InputShapes.empty())
+ {
+ const size_t numInputShapes = params.m_InputShapes.size();
+ const size_t numInputBindings = params.m_InputBindings.size();
+ if (numInputShapes < numInputBindings)
+ {
+ throw armnn::Exception(fmt::format(
+ "Not every input has its tensor shape specified: expected={0}, got={1}",
+ numInputBindings, numInputShapes));
+ }
+
+ for (size_t i = 0; i < numInputShapes; i++)
+ {
+ inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
+ }
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ network = (params.m_IsModelBinary ?
+ parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes) :
+ parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes));
+ }
+ }
+
+ else
{
ARMNN_SCOPED_HEAP_PROFILING("Parsing");
network = (params.m_IsModelBinary ?