From ef33cb192eef332fb3a26be742b341288421e5bc Mon Sep 17 00:00:00 2001 From: Kevin May Date: Fri, 29 Jan 2021 14:24:57 +0000 Subject: IVGCVSW-5592 Implement Pimpl Idiom for Caffe and Onnx Parsers Signed-off-by: Kevin May Change-Id: I760dc4f33c0f87113cda2fa924da70f2e8c19025 --- src/armnnOnnxParser/OnnxParser.cpp | 208 +++++++++++++++++++++---------------- 1 file changed, 119 insertions(+), 89 deletions(-) (limited to 'src/armnnOnnxParser/OnnxParser.cpp') diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index f3d0a73342..9f5aa1975a 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -20,6 +20,51 @@ using namespace armnn; namespace armnnOnnxParser { + +IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {} + +IOnnxParser::~IOnnxParser() = default; + +IOnnxParser* IOnnxParser::CreateRaw() +{ + return new IOnnxParser(); +} + +IOnnxParserPtr IOnnxParser::Create() +{ + return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy); +} + +void IOnnxParser::Destroy(IOnnxParser* parser) +{ + delete parser; +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile) +{ + return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) +{ + return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText) +{ + return pOnnxParserImpl->CreateNetworkFromString(protoText); +} + +BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const +{ + return pOnnxParserImpl->GetNetworkInputBindingInfo(name); +} + +BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const +{ + return pOnnxParserImpl->GetNetworkOutputBindingInfo(name); +} + namespace { void CheckValidDataType(std::initializer_list validInputTypes, @@ -357,25 +402,25 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor, } //namespace -const std::map OnnxParser::m_ParserFunctions = { - { "BatchNormalization", &OnnxParser::ParseBatchNormalization}, - { "GlobalAveragePool", &OnnxParser::ParseGlobalAveragePool}, - { "AveragePool", &OnnxParser::ParseAveragePool }, - { "Clip", &OnnxParser::ParseClip }, - { "Constant", &OnnxParser::ParseConstant }, - { "MaxPool", &OnnxParser::ParseMaxPool }, - { "Reshape", &OnnxParser::ParseReshape }, - { "Sigmoid", &OnnxParser::ParseSigmoid }, - { "Tanh", &OnnxParser::ParseTanh }, - { "Relu", &OnnxParser::ParseRelu }, - { "LeakyRelu", &OnnxParser::ParseLeakyRelu }, - { "Conv", &OnnxParser::ParseConv }, - { "Add", &OnnxParser::ParseAdd }, - { "Flatten", &OnnxParser::ParseFlatten}, +const std::map OnnxParserImpl::m_ParserFunctions = { + { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization}, + { "GlobalAveragePool", &OnnxParserImpl::ParseGlobalAveragePool}, + { "AveragePool", &OnnxParserImpl::ParseAveragePool }, + { "Clip", &OnnxParserImpl::ParseClip }, + { "Constant", &OnnxParserImpl::ParseConstant }, + { "MaxPool", &OnnxParserImpl::ParseMaxPool }, + { "Reshape", &OnnxParserImpl::ParseReshape }, + { "Sigmoid", &OnnxParserImpl::ParseSigmoid }, + { "Tanh", &OnnxParserImpl::ParseTanh }, + { "Relu", &OnnxParserImpl::ParseRelu }, + { "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu }, + { "Conv", &OnnxParserImpl::ParseConv }, + { "Add", &OnnxParserImpl::ParseAdd }, + { "Flatten", &OnnxParserImpl::ParseFlatten}, }; template -void OnnxParser::ValidateInputs(const onnx::NodeProto& node, +void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node, TypePair validInputs, const Location& location) { @@ -391,13 +436,13 @@ void OnnxParser::ValidateInputs(const onnx::NodeProto& node, } #define VALID_INPUTS(NODE, VALID_INPUTS) \ - OnnxParser::ValidateInputs(NODE, \ + OnnxParserImpl::ValidateInputs(NODE, \ VALID_INPUTS, \ CHECK_LOCATION()) -std::vector OnnxParser::ComputeOutputInfo(std::vector outNames, - const IConnectableLayer* layer, - std::vector inputShapes) +std::vector OnnxParserImpl::ComputeOutputInfo(std::vector outNames, + const IConnectableLayer* layer, + std::vector inputShapes) { ARMNN_ASSERT(! outNames.empty()); bool needCompute = std::any_of(outNames.begin(), @@ -427,33 +472,18 @@ std::vector OnnxParser::ComputeOutputInfo(std::vector o return outInfo; } -IOnnxParser* IOnnxParser::CreateRaw() -{ - return new OnnxParser(); -} - -IOnnxParserPtr IOnnxParser::Create() -{ - return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy); -} - -void IOnnxParser::Destroy(IOnnxParser* parser) -{ - delete parser; -} - -OnnxParser::OnnxParser() +OnnxParserImpl::OnnxParserImpl() : m_Network(nullptr, nullptr) { } -void OnnxParser::ResetParser() +void OnnxParserImpl::ResetParser() { m_Network = armnn::INetworkPtr(nullptr, nullptr); m_Graph = nullptr; } -void OnnxParser::Cleanup() +void OnnxParserImpl::Cleanup() { m_TensorConnections.clear(); m_TensorsInfo.clear(); @@ -461,7 +491,7 @@ void OnnxParser::Cleanup() m_OutputsFusedAndUsed.clear(); } -std::pair> OnnxParser::CreateConstTensor(const std::string name) +std::pair> OnnxParserImpl::CreateConstTensor(const std::string name) { const TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; @@ -499,7 +529,7 @@ std::pair> OnnxParser::CreateConstTensor(c return std::make_pair(ConstTensor(tensorInfo, tensorData.get()), std::move(tensorData)); } -ModelPtr OnnxParser::LoadModelFromTextFile(const char* graphFile) +ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile) { FILE* fd = fopen(graphFile, "r"); @@ -524,7 +554,7 @@ ModelPtr OnnxParser::LoadModelFromTextFile(const char* graphFile) return modelProto; } -INetworkPtr OnnxParser::CreateNetworkFromTextFile(const char* graphFile) +INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile) { ResetParser(); ModelPtr modelProto = LoadModelFromTextFile(graphFile); @@ -532,7 +562,7 @@ INetworkPtr OnnxParser::CreateNetworkFromTextFile(const char* graphFile) } -ModelPtr OnnxParser::LoadModelFromBinaryFile(const char* graphFile) +ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); @@ -560,14 +590,14 @@ ModelPtr OnnxParser::LoadModelFromBinaryFile(const char* graphFile) } -INetworkPtr OnnxParser::CreateNetworkFromBinaryFile(const char* graphFile) +INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile) { ResetParser(); ModelPtr modelProto = LoadModelFromBinaryFile(graphFile); return CreateNetworkFromModel(*modelProto); } -ModelPtr OnnxParser::LoadModelFromString(const std::string& protoText) +ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText) { if (protoText == "") { @@ -586,14 +616,14 @@ ModelPtr OnnxParser::LoadModelFromString(const std::string& protoText) return modelProto; } -INetworkPtr OnnxParser::CreateNetworkFromString(const std::string& protoText) +INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText) { ResetParser(); ModelPtr modelProto = LoadModelFromString(protoText); return CreateNetworkFromModel(*modelProto); } -INetworkPtr OnnxParser::CreateNetworkFromModel(onnx::ModelProto& model) +INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model) { m_Network = INetwork::Create(); try @@ -610,7 +640,7 @@ INetworkPtr OnnxParser::CreateNetworkFromModel(onnx::ModelProto& model) return std::move(m_Network); } -void OnnxParser::LoadGraph() +void OnnxParserImpl::LoadGraph() { ARMNN_ASSERT(m_Graph.get() != nullptr); @@ -684,7 +714,7 @@ void OnnxParser::LoadGraph() } } -void OnnxParser::SetupInfo(const google::protobuf::RepeatedPtrField* list) +void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField* list) { for (auto tensor : *list) { @@ -695,7 +725,7 @@ void OnnxParser::SetupInfo(const google::protobuf::RepeatedPtrField (static_cast(m_Graph->node_size()), UsageSummary()); auto matmulAndConstant = [&](const std::string& constInput, @@ -753,10 +783,10 @@ void OnnxParser::DetectFullyConnected() } template -void OnnxParser::GetInputAndParam(const onnx::NodeProto& node, - std::string* inputName, - std::string* constName, - const Location& location) +void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node, + std::string* inputName, + std::string* constName, + const Location& location) { int cstIndex; if (m_TensorsInfo[node.input(0)].isConstant()) @@ -786,7 +816,7 @@ void OnnxParser::GetInputAndParam(const onnx::NodeProto& node, } template -void OnnxParser::To1DTensor(const std::string& name, const Location& location) +void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location) { TensorShape shape = m_TensorsInfo[name].m_info->GetShape(); std::vector newShape; @@ -805,7 +835,7 @@ void OnnxParser::To1DTensor(const std::string& name, const Location& location) m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast(newShape.size()), newShape.data())); } -void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc) +void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc) { ARMNN_ASSERT(node.op_type() == "Conv"); @@ -864,7 +894,7 @@ void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, cons RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode) +void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode) { // find matmul inputs @@ -941,7 +971,7 @@ void OnnxParser::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx } } -void OnnxParser::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc) +void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc) { CHECK_VALID_SIZE(static_cast(node.input_size()), 1); @@ -1021,8 +1051,8 @@ void OnnxParser::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescripto RegisterOutputSlots(layer, {node.output(0)}); } -std::pair OnnxParser::AddPrepareBroadcast(const std::string& input0, - const std::string& input1) +std::pair OnnxParserImpl::AddPrepareBroadcast(const std::string& input0, + const std::string& input1) { std::pair inputs = std::make_pair(input0, input1); @@ -1044,7 +1074,7 @@ std::pair OnnxParser::AddPrepareBroadcast(const std::s return inputs; } -void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) +void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) { auto armnnTensor = CreateConstTensor(tensorName); @@ -1053,9 +1083,9 @@ void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::s RegisterOutputSlots(layer, {tensorName}); } -void OnnxParser::CreateReshapeLayer(const std::string& inputName, - const std::string& outputName, - const std::string& layerName) +void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName, + const std::string& outputName, + const std::string& layerName) { const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info; ReshapeDescriptor reshapeDesc; @@ -1073,7 +1103,7 @@ void OnnxParser::CreateReshapeLayer(const std::string& inputName, RegisterOutputSlots(layer, {outputName}); } -void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func) +void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func) { CHECK_VALID_SIZE(static_cast(node.input_size()), 1, 3); CHECK_VALID_SIZE(static_cast(node.output_size()), 1); @@ -1103,32 +1133,32 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseClip(const onnx::NodeProto& node) +void OnnxParserImpl::ParseClip(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::BoundedReLu); } -void OnnxParser::ParseSigmoid(const onnx::NodeProto& node) +void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::Sigmoid); } -void OnnxParser::ParseTanh(const onnx::NodeProto& node) +void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::TanH); } -void OnnxParser::ParseRelu(const onnx::NodeProto& node) +void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::ReLu); } -void OnnxParser::ParseLeakyRelu(const onnx::NodeProto& node) +void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::LeakyReLu); } -void OnnxParser::ParseAdd(const onnx::NodeProto& node) +void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.input_size()), 2); CHECK_VALID_SIZE(static_cast(node.output_size()), 1); @@ -1186,7 +1216,7 @@ void OnnxParser::ParseAdd(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseAveragePool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc; desc.m_PoolType = PoolingAlgorithm::Average; @@ -1199,7 +1229,7 @@ void OnnxParser::ParseAveragePool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } -void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node) +void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node) { //IGNORE momentum parameter and spatial parameters @@ -1246,7 +1276,7 @@ void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseConstant(const onnx::NodeProto& node) +void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.attribute_size()), 1); if (!node.attribute(0).has_t()) @@ -1269,7 +1299,7 @@ void OnnxParser::ParseConstant(const onnx::NodeProto& node) CreateConstantLayer(node.output(0), node.name()); } -void OnnxParser::ParseConv(const onnx::NodeProto& node) +void OnnxParserImpl::ParseConv(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.input_size()), 2, 3); //input, weight, (bias) CHECK_VALID_SIZE(static_cast(node.output_size()), 1); @@ -1462,7 +1492,7 @@ void OnnxParser::ParseConv(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseFlatten(const onnx::NodeProto& node) +void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.input_size()), 1); CHECK_VALID_SIZE(static_cast(node.output_size()), 1); @@ -1509,7 +1539,7 @@ void OnnxParser::ParseFlatten(const onnx::NodeProto& node) CreateReshapeLayer(node.input(0), node.output(0), node.name()); } -void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); desc.m_PoolType = PoolingAlgorithm::Average; @@ -1533,7 +1563,7 @@ void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseMaxPool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node) { Pooling2dDescriptor desc; desc.m_PoolType = PoolingAlgorithm::Max; @@ -1541,7 +1571,7 @@ void OnnxParser::ParseMaxPool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } -void OnnxParser::ParseReshape(const onnx::NodeProto& node) +void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast(node.input_size()), 2); CHECK_VALID_SIZE(static_cast(node.output_size()), 1); @@ -1594,9 +1624,9 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node) } } -void OnnxParser::PrependForBroadcast(const std::string& outputName, - const std::string& input0, - const std::string& input1) +void OnnxParserImpl::PrependForBroadcast(const std::string& outputName, + const std::string& input0, + const std::string& input1) { //input0 should be reshaped to have same number of dim as input1 TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info); @@ -1633,7 +1663,7 @@ void OnnxParser::PrependForBroadcast(const std::string& outputName, } } -void OnnxParser::SetupInputLayers() +void OnnxParserImpl::SetupInputLayers() { //Find user input and add their layers for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex) @@ -1651,7 +1681,7 @@ void OnnxParser::SetupInputLayers() } } -void OnnxParser::SetupOutputLayers() +void OnnxParserImpl::SetupOutputLayers() { if(m_Graph->output_size() == 0) { @@ -1668,7 +1698,7 @@ void OnnxParser::SetupOutputLayers() } } -void OnnxParser::RegisterInputSlots(IConnectableLayer* layer, const std::vector& tensorIds) +void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector& tensorIds) { ARMNN_ASSERT(layer != nullptr); if (tensorIds.size() != layer->GetNumInputSlots()) @@ -1695,7 +1725,7 @@ void OnnxParser::RegisterInputSlots(IConnectableLayer* layer, const std::vector< } } -void OnnxParser::RegisterOutputSlots(IConnectableLayer* layer, const std::vector& tensorIds) +void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector& tensorIds) { ARMNN_ASSERT(layer != nullptr); if (tensorIds.size() != layer->GetNumOutputSlots()) @@ -1734,7 +1764,7 @@ void OnnxParser::RegisterOutputSlots(IConnectableLayer* layer, const std::vector } } -BindingPointInfo OnnxParser::GetNetworkInputBindingInfo(const std::string& name) const +BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const { for(int i = 0; i < m_Graph->input_size(); ++i) { @@ -1748,7 +1778,7 @@ BindingPointInfo OnnxParser::GetNetworkInputBindingInfo(const std::string& name) name, CHECK_LOCATION().AsString())); } -BindingPointInfo OnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const +BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const { for(int i = 0; i < m_Graph->output_size(); ++i) { @@ -1762,7 +1792,7 @@ BindingPointInfo OnnxParser::GetNetworkOutputBindingInfo(const std::string& name name, CHECK_LOCATION().AsString())); } -std::vector OnnxParser::GetInputs(ModelPtr& model) +std::vector OnnxParserImpl::GetInputs(ModelPtr& model) { if(model == nullptr) { throw InvalidArgumentException(fmt::format("The given model cannot be null {}", @@ -1786,7 +1816,7 @@ std::vector OnnxParser::GetInputs(ModelPtr& model) return inputNames; } -std::vector OnnxParser::GetOutputs(ModelPtr& model) +std::vector OnnxParserImpl::GetOutputs(ModelPtr& model) { if(model == nullptr) { throw InvalidArgumentException(fmt::format("The given model cannot be null {}", -- cgit v1.2.1