aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-13 11:44:50 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-10-18 08:32:02 +0000
commit1b11f32dbfea8383956c5d2c60b034469194f6d9 (patch)
tree3bd3f73e9af499778db894c3db18dc7b5f4ee668 /tests
parentea0712e72080b794fa864e67d073d3bfe2eda0f1 (diff)
downloadarmnn-1b11f32dbfea8383956c5d2c60b034469194f6d9.tar.gz
IVGCVSW-6450 Add Support of Models with Dynamic Batch Tensor to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ia7dbf0735619d406d6b4e34a71f14f20d92586e6
Diffstat (limited to 'tests')
-rw-r--r--tests/InferenceModel.hpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp
index 02511965d9..cf3aae137e 100644
--- a/tests/InferenceModel.hpp
+++ b/tests/InferenceModel.hpp
@@ -312,6 +312,32 @@ public:
armnn::INetworkPtr network{nullptr, [](armnn::INetwork *){}};
+ std::map<std::string, armnn::TensorShape> inputShapes;
+ if (!params.m_InputShapes.empty())
+ {
+ const size_t numInputShapes = params.m_InputShapes.size();
+ const size_t numInputBindings = params.m_InputBindings.size();
+ if (numInputShapes < numInputBindings)
+ {
+ throw armnn::Exception(fmt::format(
+ "Not every input has its tensor shape specified: expected={0}, got={1}",
+ numInputBindings, numInputShapes));
+ }
+
+ for (size_t i = 0; i < numInputShapes; i++)
+ {
+ inputShapes[params.m_InputBindings[i]] = params.m_InputShapes[i];
+ }
+
+ {
+ ARMNN_SCOPED_HEAP_PROFILING("Parsing");
+ network = (params.m_IsModelBinary ?
+ parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes) :
+ parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes));
+ }
+ }
+
+ else
{
ARMNN_SCOPED_HEAP_PROFILING("Parsing");
network = (params.m_IsModelBinary ?