diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-13 11:44:50 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-10-18 08:32:02 +0000 |
commit | 1b11f32dbfea8383956c5d2c60b034469194f6d9 (patch) | |
tree | 3bd3f73e9af499778db894c3db18dc7b5f4ee668 /src/armnnUtils/ParserPrototxtFixture.hpp | |
parent | ea0712e72080b794fa864e67d073d3bfe2eda0f1 (diff) | |
download | armnn-1b11f32dbfea8383956c5d2c60b034469194f6d9.tar.gz |
IVGCVSW-6450 Add Support of Models with Dynamic Batch Tensor to ONNX parser
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ia7dbf0735619d406d6b4e34a71f14f20d92586e6
Diffstat (limited to 'src/armnnUtils/ParserPrototxtFixture.hpp')
-rw-r--r-- | src/armnnUtils/ParserPrototxtFixture.hpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp index 08ac3aeb9b..3c659d3fd6 100644 --- a/src/armnnUtils/ParserPrototxtFixture.hpp +++ b/src/armnnUtils/ParserPrototxtFixture.hpp @@ -42,6 +42,7 @@ struct ParserPrototxtFixture const std::string& outputName); void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes, const std::vector<std::string>& requestedOutputs); + void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes); void Setup(); armnn::IOptimizedNetworkPtr SetupOptimizedNetwork( const std::map<std::string,armnn::TensorShape>& inputShapes, @@ -136,6 +137,23 @@ void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::Te } template<typename TParser> +void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes) +{ + std::string errorMessage; + + armnn::INetworkPtr network = + m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes); + auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec()); + armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage); + if (ret != armnn::Status::Success) + { + throw armnn::Exception(fmt::format("LoadNetwork failed with error: '{0}' {1}", + errorMessage, + CHECK_LOCATION().AsString())); + } +} + +template<typename TParser> void ParserPrototxtFixture<TParser>::Setup() { std::string errorMessage; @@ -191,6 +209,15 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve { armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first); inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) }); + if (bindingInfo.second.GetNumElements() != it.second.size()) + { + throw armnn::Exception(fmt::format("Input tensor {0} is expected to have {1} elements. " + "{2} elements supplied. {3}", + it.first, + bindingInfo.second.GetNumElements(), + it.second.size(), + CHECK_LOCATION().AsString())); + } } // Allocates storage for the output tensors to be written to and sets up the armnn output tensors. |