aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ParserPrototxtFixture.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/ParserPrototxtFixture.hpp')
-rw-r--r--src/armnnUtils/ParserPrototxtFixture.hpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp
index 08ac3aeb9b..3c659d3fd6 100644
--- a/src/armnnUtils/ParserPrototxtFixture.hpp
+++ b/src/armnnUtils/ParserPrototxtFixture.hpp
@@ -42,6 +42,7 @@ struct ParserPrototxtFixture
const std::string& outputName);
void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes,
const std::vector<std::string>& requestedOutputs);
+ void Setup(const std::map<std::string, armnn::TensorShape>& inputShapes);
void Setup();
armnn::IOptimizedNetworkPtr SetupOptimizedNetwork(
const std::map<std::string,armnn::TensorShape>& inputShapes,
@@ -136,6 +137,23 @@ void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::Te
}
template<typename TParser>
+void ParserPrototxtFixture<TParser>::Setup(const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ std::string errorMessage;
+
+ armnn::INetworkPtr network =
+ m_Parser->CreateNetworkFromString(m_Prototext.c_str(), inputShapes);
+ auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
+ armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
+ if (ret != armnn::Status::Success)
+ {
+ throw armnn::Exception(fmt::format("LoadNetwork failed with error: '{0}' {1}",
+ errorMessage,
+ CHECK_LOCATION().AsString()));
+ }
+}
+
+template<typename TParser>
void ParserPrototxtFixture<TParser>::Setup()
{
std::string errorMessage;
@@ -191,6 +209,15 @@ void ParserPrototxtFixture<TParser>::RunTest(const std::map<std::string, std::ve
{
armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(it.first);
inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
+ if (bindingInfo.second.GetNumElements() != it.second.size())
+ {
+ throw armnn::Exception(fmt::format("Input tensor {0} is expected to have {1} elements. "
+ "{2} elements supplied. {3}",
+ it.first,
+ bindingInfo.second.GetNumElements(),
+ it.second.size(),
+ CHECK_LOCATION().AsString()));
+ }
}
// Allocates storage for the output tensors to be written to and sets up the armnn output tensors.