aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp110
1 files changed, 101 insertions, 9 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 3588975897..eb24bb5425 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -60,6 +60,25 @@ armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& proto
return pOnnxParserImpl->CreateNetworkFromString(protoText);
}
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(
+ const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile, inputShapes);
+}
+
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile, inputShapes);
+}
+
+armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ return pOnnxParserImpl->CreateNetworkFromString(protoText, inputShapes);
+}
+
BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const
{
return pOnnxParserImpl->GetNetworkInputBindingInfo(name);
@@ -287,12 +306,18 @@ armnn::TensorInfo ToTensorInfo(const std::string& name, std::vector<unsigned int
}
}
- // To avoid crashes by trivial tensors
+ // Scalar Tensor
if (shape.empty())
{
return TensorInfo(TensorShape(Dimensionality::Scalar), type);
}
+ // Dynamic Tensor
+ if(std::find(shape.begin(), shape.end(), 0) != shape.end())
+ {
+ return TensorInfo(TensorShape(Dimensionality::NotSpecified), type);
+ }
+
return TensorInfo(TensorShape(static_cast<unsigned int>(shape.size()), shape.data()), type);
}
@@ -469,7 +494,9 @@ std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::strin
outNames.end(),
[this](std::string name)
{
- return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr);
+ return (m_TensorsInfo.count(name) == 0 || m_TensorsInfo[name].m_info == nullptr
+ || m_TensorsInfo[name].m_info->GetShape().GetDimensionality() ==
+ Dimensionality::NotSpecified);
});
std::vector<TensorInfo> outInfo;
//if the output info(s) are not here, we need to compute them
@@ -521,6 +548,8 @@ void OnnxParserImpl::ResetParser()
{
m_Network = armnn::INetworkPtr(nullptr, nullptr);
m_Graph = nullptr;
+ m_InputInfos.clear();
+ m_OutputInfos.clear();
}
void OnnxParserImpl::Cleanup()
@@ -529,6 +558,7 @@ void OnnxParserImpl::Cleanup()
m_TensorsInfo.clear();
m_OutputsMap.clear();
m_OutputsFusedAndUsed.clear();
+ m_InputShapes.clear();
}
template<typename T>
@@ -692,6 +722,14 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile)
return CreateNetworkFromModel(*modelProto);
}
+INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ ResetParser();
+ m_InputShapes = inputShapes;
+ ModelPtr modelProto = LoadModelFromTextFile(graphFile);
+ return CreateNetworkFromModel(*modelProto);
+}
ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile)
{
@@ -728,6 +766,15 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
return CreateNetworkFromModel(*modelProto);
}
+INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ ResetParser();
+ m_InputShapes = inputShapes;
+ ModelPtr modelProto = LoadModelFromBinaryFile(graphFile);
+ return CreateNetworkFromModel(*modelProto);
+}
+
ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText)
{
if (protoText == "")
@@ -754,6 +801,15 @@ INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText
return CreateNetworkFromModel(*modelProto);
}
+INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText,
+ const std::map<std::string, armnn::TensorShape>& inputShapes)
+{
+ ResetParser();
+ m_InputShapes = inputShapes;
+ ModelPtr modelProto = LoadModelFromString(protoText);
+ return CreateNetworkFromModel(*modelProto);
+}
+
INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model)
{
m_Network = INetwork::Create();
@@ -843,6 +899,13 @@ void OnnxParserImpl::LoadGraph()
}
}
}
+
+ // Get output info.
+ for(int outputIndex = 0; outputIndex < m_Graph->output_size(); ++outputIndex)
+ {
+ auto output = m_Graph->output(outputIndex);
+ m_OutputInfos[output.name()] = *m_TensorsInfo[output.name()].m_info;
+ }
}
void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list)
@@ -2172,13 +2235,31 @@ void OnnxParserImpl::SetupInputLayers()
for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex)
{
auto input = m_Graph->input(inputIndex);
- if (! m_TensorsInfo[input.name()].isConstant())
+ if (!m_TensorsInfo[input.name()].isConstant())
{
IConnectableLayer* layer =
- m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
- auto tensorInfo = ToTensorInfo(input);
+ m_Network->AddInputLayer(static_cast<armnn::LayerBindingId>(inputIndex), input.name().c_str());
+ TensorInfo tensorInfo = *m_TensorsInfo[input.name()].m_info;
+ if (tensorInfo.GetShape().GetDimensionality() == Dimensionality::NotSpecified)
+ {
+ if (m_InputShapes.find(input.name()) == m_InputShapes.end())
+ {
+ throw ParseException(fmt::format("The parser does not support dynamic tensor, "
+ "please specify input shape for {}. {}",
+ input.name(),
+ CHECK_LOCATION().AsString()));
+ }
+ else
+ {
+ tensorInfo.SetShape(m_InputShapes[input.name()]);
+ m_TensorsInfo[input.name()].m_info = std::make_unique<TensorInfo>(tensorInfo);
+ }
+
+ }
layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+ m_InputInfos[input.name()] = tensorInfo;
+
RegisterOutputSlots(layer,{ input.name() });
}
}
@@ -2211,7 +2292,7 @@ void OnnxParserImpl::RegisterInputSlot(IConnectableLayer* layer,
if (it == m_TensorConnections.end())
{
- //First time seing this tensor, we need to map it
+ //First time seeing this tensor, we need to map it
m_TensorConnections[tensorId] = TensorSlots();
}
m_TensorConnections[tensorId].inputSlots.push_back(slot);
@@ -2238,7 +2319,7 @@ void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vec
if (it == m_TensorConnections.end())
{
- //First time seing this tensor, we need to map it
+ // First time seing this tensor, we need to map it
m_TensorConnections[tensorId] = TensorSlots();
}
m_TensorConnections[tensorId].inputSlots.push_back(slot);
@@ -2282,6 +2363,7 @@ void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::ve
}
tensorSlots.outputSlot = slot;
}
+
}
BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const
@@ -2291,7 +2373,12 @@ BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& n
auto input = m_Graph->input(i);
if(input.name() == name)
{
- return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(input));
+ auto it = m_InputInfos.find(name);
+
+ if (it != m_InputInfos.end())
+ {
+ return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
+ }
}
}
throw InvalidArgumentException(fmt::format("The input layer '{}' does not exist {}",
@@ -2305,7 +2392,12 @@ BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string&
auto output = m_Graph->output(i);
if(output.name() == name)
{
- return std::make_pair(static_cast<armnn::LayerBindingId>(i), ToTensorInfo(output));
+ auto it = m_OutputInfos.find(name);
+
+ if (it != m_OutputInfos.end())
+ {
+ return std::make_pair(static_cast<armnn::LayerBindingId>(i), it->second);
+ }
}
}
throw InvalidArgumentException(fmt::format("The output layer '{}' does not exist {}",