diff options
Diffstat (limited to 'src/armnnOnnxParser')
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.cpp | 208 | ||||
-rw-r--r-- | src/armnnOnnxParser/OnnxParser.hpp | 17 | ||||
-rw-r--r-- | src/armnnOnnxParser/test/GetInputsOutputs.cpp | 20 |
3 files changed, 138 insertions, 107 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp index f3d0a73342..9f5aa1975a 100644 --- a/src/armnnOnnxParser/OnnxParser.cpp +++ b/src/armnnOnnxParser/OnnxParser.cpp @@ -20,6 +20,51 @@ using namespace armnn; namespace armnnOnnxParser { + +IOnnxParser::IOnnxParser() : pOnnxParserImpl(new OnnxParserImpl()) {} + +IOnnxParser::~IOnnxParser() = default; + +IOnnxParser* IOnnxParser::CreateRaw() +{ + return new IOnnxParser(); +} + +IOnnxParserPtr IOnnxParser::Create() +{ + return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy); +} + +void IOnnxParser::Destroy(IOnnxParser* parser) +{ + delete parser; +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromBinaryFile(const char* graphFile) +{ + return pOnnxParserImpl->CreateNetworkFromBinaryFile(graphFile); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromTextFile(const char* graphFile) +{ + return pOnnxParserImpl->CreateNetworkFromTextFile(graphFile); +} + +armnn::INetworkPtr IOnnxParser::CreateNetworkFromString(const std::string& protoText) +{ + return pOnnxParserImpl->CreateNetworkFromString(protoText); +} + +BindingPointInfo IOnnxParser::GetNetworkInputBindingInfo(const std::string& name) const +{ + return pOnnxParserImpl->GetNetworkInputBindingInfo(name); +} + +BindingPointInfo IOnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const +{ + return pOnnxParserImpl->GetNetworkOutputBindingInfo(name); +} + namespace { void CheckValidDataType(std::initializer_list<onnx::TensorProto::DataType> validInputTypes, @@ -357,25 +402,25 @@ TensorInfo ComputeReshapeInfo(const TensorShape& targetShapeTensor, } //namespace -const std::map<std::string, OnnxParser::OperationParsingFunction> OnnxParser::m_ParserFunctions = { - { "BatchNormalization", &OnnxParser::ParseBatchNormalization}, - { "GlobalAveragePool", &OnnxParser::ParseGlobalAveragePool}, - { "AveragePool", &OnnxParser::ParseAveragePool }, - { "Clip", &OnnxParser::ParseClip }, - { "Constant", &OnnxParser::ParseConstant }, - { "MaxPool", &OnnxParser::ParseMaxPool }, - { "Reshape", &OnnxParser::ParseReshape }, - { "Sigmoid", &OnnxParser::ParseSigmoid }, - { "Tanh", &OnnxParser::ParseTanh }, - { "Relu", &OnnxParser::ParseRelu }, - { "LeakyRelu", &OnnxParser::ParseLeakyRelu }, - { "Conv", &OnnxParser::ParseConv }, - { "Add", &OnnxParser::ParseAdd }, - { "Flatten", &OnnxParser::ParseFlatten}, +const std::map<std::string, OnnxParserImpl::OperationParsingFunction> OnnxParserImpl::m_ParserFunctions = { + { "BatchNormalization", &OnnxParserImpl::ParseBatchNormalization}, + { "GlobalAveragePool", &OnnxParserImpl::ParseGlobalAveragePool}, + { "AveragePool", &OnnxParserImpl::ParseAveragePool }, + { "Clip", &OnnxParserImpl::ParseClip }, + { "Constant", &OnnxParserImpl::ParseConstant }, + { "MaxPool", &OnnxParserImpl::ParseMaxPool }, + { "Reshape", &OnnxParserImpl::ParseReshape }, + { "Sigmoid", &OnnxParserImpl::ParseSigmoid }, + { "Tanh", &OnnxParserImpl::ParseTanh }, + { "Relu", &OnnxParserImpl::ParseRelu }, + { "LeakyRelu", &OnnxParserImpl::ParseLeakyRelu }, + { "Conv", &OnnxParserImpl::ParseConv }, + { "Add", &OnnxParserImpl::ParseAdd }, + { "Flatten", &OnnxParserImpl::ParseFlatten}, }; template<typename TypePair, typename Location> -void OnnxParser::ValidateInputs(const onnx::NodeProto& node, +void OnnxParserImpl::ValidateInputs(const onnx::NodeProto& node, TypePair validInputs, const Location& location) { @@ -391,13 +436,13 @@ void OnnxParser::ValidateInputs(const onnx::NodeProto& node, } #define VALID_INPUTS(NODE, VALID_INPUTS) \ - OnnxParser::ValidateInputs(NODE, \ + OnnxParserImpl::ValidateInputs(NODE, \ VALID_INPUTS, \ CHECK_LOCATION()) -std::vector<TensorInfo> OnnxParser::ComputeOutputInfo(std::vector<std::string> outNames, - const IConnectableLayer* layer, - std::vector<TensorShape> inputShapes) +std::vector<TensorInfo> OnnxParserImpl::ComputeOutputInfo(std::vector<std::string> outNames, + const IConnectableLayer* layer, + std::vector<TensorShape> inputShapes) { ARMNN_ASSERT(! outNames.empty()); bool needCompute = std::any_of(outNames.begin(), @@ -427,33 +472,18 @@ std::vector<TensorInfo> OnnxParser::ComputeOutputInfo(std::vector<std::string> o return outInfo; } -IOnnxParser* IOnnxParser::CreateRaw() -{ - return new OnnxParser(); -} - -IOnnxParserPtr IOnnxParser::Create() -{ - return IOnnxParserPtr(CreateRaw(), &IOnnxParser::Destroy); -} - -void IOnnxParser::Destroy(IOnnxParser* parser) -{ - delete parser; -} - -OnnxParser::OnnxParser() +OnnxParserImpl::OnnxParserImpl() : m_Network(nullptr, nullptr) { } -void OnnxParser::ResetParser() +void OnnxParserImpl::ResetParser() { m_Network = armnn::INetworkPtr(nullptr, nullptr); m_Graph = nullptr; } -void OnnxParser::Cleanup() +void OnnxParserImpl::Cleanup() { m_TensorConnections.clear(); m_TensorsInfo.clear(); @@ -461,7 +491,7 @@ void OnnxParser::Cleanup() m_OutputsFusedAndUsed.clear(); } -std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParser::CreateConstTensor(const std::string name) +std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParserImpl::CreateConstTensor(const std::string name) { const TensorInfo tensorInfo = *m_TensorsInfo[name].m_info; onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor; @@ -499,7 +529,7 @@ std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParser::CreateConstTensor(c return std::make_pair(ConstTensor(tensorInfo, tensorData.get()), std::move(tensorData)); } -ModelPtr OnnxParser::LoadModelFromTextFile(const char* graphFile) +ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile) { FILE* fd = fopen(graphFile, "r"); @@ -524,7 +554,7 @@ ModelPtr OnnxParser::LoadModelFromTextFile(const char* graphFile) return modelProto; } -INetworkPtr OnnxParser::CreateNetworkFromTextFile(const char* graphFile) +INetworkPtr OnnxParserImpl::CreateNetworkFromTextFile(const char* graphFile) { ResetParser(); ModelPtr modelProto = LoadModelFromTextFile(graphFile); @@ -532,7 +562,7 @@ INetworkPtr OnnxParser::CreateNetworkFromTextFile(const char* graphFile) } -ModelPtr OnnxParser::LoadModelFromBinaryFile(const char* graphFile) +ModelPtr OnnxParserImpl::LoadModelFromBinaryFile(const char* graphFile) { FILE* fd = fopen(graphFile, "rb"); @@ -560,14 +590,14 @@ ModelPtr OnnxParser::LoadModelFromBinaryFile(const char* graphFile) } -INetworkPtr OnnxParser::CreateNetworkFromBinaryFile(const char* graphFile) +INetworkPtr OnnxParserImpl::CreateNetworkFromBinaryFile(const char* graphFile) { ResetParser(); ModelPtr modelProto = LoadModelFromBinaryFile(graphFile); return CreateNetworkFromModel(*modelProto); } -ModelPtr OnnxParser::LoadModelFromString(const std::string& protoText) +ModelPtr OnnxParserImpl::LoadModelFromString(const std::string& protoText) { if (protoText == "") { @@ -586,14 +616,14 @@ ModelPtr OnnxParser::LoadModelFromString(const std::string& protoText) return modelProto; } -INetworkPtr OnnxParser::CreateNetworkFromString(const std::string& protoText) +INetworkPtr OnnxParserImpl::CreateNetworkFromString(const std::string& protoText) { ResetParser(); ModelPtr modelProto = LoadModelFromString(protoText); return CreateNetworkFromModel(*modelProto); } -INetworkPtr OnnxParser::CreateNetworkFromModel(onnx::ModelProto& model) +INetworkPtr OnnxParserImpl::CreateNetworkFromModel(onnx::ModelProto& model) { m_Network = INetwork::Create(); try @@ -610,7 +640,7 @@ INetworkPtr OnnxParser::CreateNetworkFromModel(onnx::ModelProto& model) return std::move(m_Network); } -void OnnxParser::LoadGraph() +void OnnxParserImpl::LoadGraph() { ARMNN_ASSERT(m_Graph.get() != nullptr); @@ -684,7 +714,7 @@ void OnnxParser::LoadGraph() } } -void OnnxParser::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list) +void OnnxParserImpl::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list) { for (auto tensor : *list) { @@ -695,7 +725,7 @@ void OnnxParser::SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueI } } -void OnnxParser::DetectFullyConnected() +void OnnxParserImpl::DetectFullyConnected() { m_OutputsFusedAndUsed = std::vector<UsageSummary> (static_cast<size_t>(m_Graph->node_size()), UsageSummary()); auto matmulAndConstant = [&](const std::string& constInput, @@ -753,10 +783,10 @@ void OnnxParser::DetectFullyConnected() } template<typename Location> -void OnnxParser::GetInputAndParam(const onnx::NodeProto& node, - std::string* inputName, - std::string* constName, - const Location& location) +void OnnxParserImpl::GetInputAndParam(const onnx::NodeProto& node, + std::string* inputName, + std::string* constName, + const Location& location) { int cstIndex; if (m_TensorsInfo[node.input(0)].isConstant()) @@ -786,7 +816,7 @@ void OnnxParser::GetInputAndParam(const onnx::NodeProto& node, } template<typename Location> -void OnnxParser::To1DTensor(const std::string& name, const Location& location) +void OnnxParserImpl::To1DTensor(const std::string& name, const Location& location) { TensorShape shape = m_TensorsInfo[name].m_info->GetShape(); std::vector<uint32_t> newShape; @@ -805,7 +835,7 @@ void OnnxParser::To1DTensor(const std::string& name, const Location& location) m_TensorsInfo[name].m_info->SetShape(TensorShape(static_cast<unsigned int>(newShape.size()), newShape.data())); } -void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc) +void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const Convolution2dDescriptor& convDesc) { ARMNN_ASSERT(node.op_type() == "Conv"); @@ -864,7 +894,7 @@ void OnnxParser::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, cons RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode) +void OnnxParserImpl::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode) { // find matmul inputs @@ -941,7 +971,7 @@ void OnnxParser::AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx } } -void OnnxParser::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc) +void OnnxParserImpl::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescriptor& desc) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1); @@ -1021,8 +1051,8 @@ void OnnxParser::AddPoolingLayer(const onnx::NodeProto& node, Pooling2dDescripto RegisterOutputSlots(layer, {node.output(0)}); } -std::pair<std::string, std::string> OnnxParser::AddPrepareBroadcast(const std::string& input0, - const std::string& input1) +std::pair<std::string, std::string> OnnxParserImpl::AddPrepareBroadcast(const std::string& input0, + const std::string& input1) { std::pair<std::string, std::string> inputs = std::make_pair(input0, input1); @@ -1044,7 +1074,7 @@ std::pair<std::string, std::string> OnnxParser::AddPrepareBroadcast(const std::s return inputs; } -void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) +void OnnxParserImpl::CreateConstantLayer(const std::string& tensorName, const std::string& layerName) { auto armnnTensor = CreateConstTensor(tensorName); @@ -1053,9 +1083,9 @@ void OnnxParser::CreateConstantLayer(const std::string& tensorName, const std::s RegisterOutputSlots(layer, {tensorName}); } -void OnnxParser::CreateReshapeLayer(const std::string& inputName, - const std::string& outputName, - const std::string& layerName) +void OnnxParserImpl::CreateReshapeLayer(const std::string& inputName, + const std::string& outputName, + const std::string& layerName) { const TensorInfo outputTensorInfo = *m_TensorsInfo[outputName].m_info; ReshapeDescriptor reshapeDesc; @@ -1073,7 +1103,7 @@ void OnnxParser::CreateReshapeLayer(const std::string& inputName, RegisterOutputSlots(layer, {outputName}); } -void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func) +void OnnxParserImpl::ParseActivation(const onnx::NodeProto& node, const armnn::ActivationFunction func) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1, 3); CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); @@ -1103,32 +1133,32 @@ void OnnxParser::ParseActivation(const onnx::NodeProto& node, const armnn::Activ RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseClip(const onnx::NodeProto& node) +void OnnxParserImpl::ParseClip(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::BoundedReLu); } -void OnnxParser::ParseSigmoid(const onnx::NodeProto& node) +void OnnxParserImpl::ParseSigmoid(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::Sigmoid); } -void OnnxParser::ParseTanh(const onnx::NodeProto& node) +void OnnxParserImpl::ParseTanh(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::TanH); } -void OnnxParser::ParseRelu(const onnx::NodeProto& node) +void OnnxParserImpl::ParseRelu(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::ReLu); } -void OnnxParser::ParseLeakyRelu(const onnx::NodeProto& node) +void OnnxParserImpl::ParseLeakyRelu(const onnx::NodeProto& node) { ParseActivation(node, ActivationFunction::LeakyReLu); } -void OnnxParser::ParseAdd(const onnx::NodeProto& node) +void OnnxParserImpl::ParseAdd(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); @@ -1186,7 +1216,7 @@ void OnnxParser::ParseAdd(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseAveragePool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc; desc.m_PoolType = PoolingAlgorithm::Average; @@ -1199,7 +1229,7 @@ void OnnxParser::ParseAveragePool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } -void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node) +void OnnxParserImpl::ParseBatchNormalization(const onnx::NodeProto& node) { //IGNORE momentum parameter and spatial parameters @@ -1246,7 +1276,7 @@ void OnnxParser::ParseBatchNormalization(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseConstant(const onnx::NodeProto& node) +void OnnxParserImpl::ParseConstant(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.attribute_size()), 1); if (!node.attribute(0).has_t()) @@ -1269,7 +1299,7 @@ void OnnxParser::ParseConstant(const onnx::NodeProto& node) CreateConstantLayer(node.output(0), node.name()); } -void OnnxParser::ParseConv(const onnx::NodeProto& node) +void OnnxParserImpl::ParseConv(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2, 3); //input, weight, (bias) CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); @@ -1462,7 +1492,7 @@ void OnnxParser::ParseConv(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseFlatten(const onnx::NodeProto& node) +void OnnxParserImpl::ParseFlatten(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 1); CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); @@ -1509,7 +1539,7 @@ void OnnxParser::ParseFlatten(const onnx::NodeProto& node) CreateReshapeLayer(node.input(0), node.output(0), node.name()); } -void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseGlobalAveragePool(const onnx::NodeProto& node) { Pooling2dDescriptor desc = Pooling2dDescriptor(); desc.m_PoolType = PoolingAlgorithm::Average; @@ -1533,7 +1563,7 @@ void OnnxParser::ParseGlobalAveragePool(const onnx::NodeProto& node) RegisterOutputSlots(layer, {node.output(0)}); } -void OnnxParser::ParseMaxPool(const onnx::NodeProto& node) +void OnnxParserImpl::ParseMaxPool(const onnx::NodeProto& node) { Pooling2dDescriptor desc; desc.m_PoolType = PoolingAlgorithm::Max; @@ -1541,7 +1571,7 @@ void OnnxParser::ParseMaxPool(const onnx::NodeProto& node) AddPoolingLayer(node, desc); } -void OnnxParser::ParseReshape(const onnx::NodeProto& node) +void OnnxParserImpl::ParseReshape(const onnx::NodeProto& node) { CHECK_VALID_SIZE(static_cast<size_t>(node.input_size()), 2); CHECK_VALID_SIZE(static_cast<size_t>(node.output_size()), 1); @@ -1594,9 +1624,9 @@ void OnnxParser::ParseReshape(const onnx::NodeProto& node) } } -void OnnxParser::PrependForBroadcast(const std::string& outputName, - const std::string& input0, - const std::string& input1) +void OnnxParserImpl::PrependForBroadcast(const std::string& outputName, + const std::string& input0, + const std::string& input1) { //input0 should be reshaped to have same number of dim as input1 TensorInfo outputTensorInfo = TensorInfo(*m_TensorsInfo[input0].m_info); @@ -1633,7 +1663,7 @@ void OnnxParser::PrependForBroadcast(const std::string& outputName, } } -void OnnxParser::SetupInputLayers() +void OnnxParserImpl::SetupInputLayers() { //Find user input and add their layers for(int inputIndex = 0; inputIndex < m_Graph->input_size(); ++inputIndex) @@ -1651,7 +1681,7 @@ void OnnxParser::SetupInputLayers() } } -void OnnxParser::SetupOutputLayers() +void OnnxParserImpl::SetupOutputLayers() { if(m_Graph->output_size() == 0) { @@ -1668,7 +1698,7 @@ void OnnxParser::SetupOutputLayers() } } -void OnnxParser::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds) +void OnnxParserImpl::RegisterInputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds) { ARMNN_ASSERT(layer != nullptr); if (tensorIds.size() != layer->GetNumInputSlots()) @@ -1695,7 +1725,7 @@ void OnnxParser::RegisterInputSlots(IConnectableLayer* layer, const std::vector< } } -void OnnxParser::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds) +void OnnxParserImpl::RegisterOutputSlots(IConnectableLayer* layer, const std::vector<std::string>& tensorIds) { ARMNN_ASSERT(layer != nullptr); if (tensorIds.size() != layer->GetNumOutputSlots()) @@ -1734,7 +1764,7 @@ void OnnxParser::RegisterOutputSlots(IConnectableLayer* layer, const std::vector } } -BindingPointInfo OnnxParser::GetNetworkInputBindingInfo(const std::string& name) const +BindingPointInfo OnnxParserImpl::GetNetworkInputBindingInfo(const std::string& name) const { for(int i = 0; i < m_Graph->input_size(); ++i) { @@ -1748,7 +1778,7 @@ BindingPointInfo OnnxParser::GetNetworkInputBindingInfo(const std::string& name) name, CHECK_LOCATION().AsString())); } -BindingPointInfo OnnxParser::GetNetworkOutputBindingInfo(const std::string& name) const +BindingPointInfo OnnxParserImpl::GetNetworkOutputBindingInfo(const std::string& name) const { for(int i = 0; i < m_Graph->output_size(); ++i) { @@ -1762,7 +1792,7 @@ BindingPointInfo OnnxParser::GetNetworkOutputBindingInfo(const std::string& name name, CHECK_LOCATION().AsString())); } -std::vector<std::string> OnnxParser::GetInputs(ModelPtr& model) +std::vector<std::string> OnnxParserImpl::GetInputs(ModelPtr& model) { if(model == nullptr) { throw InvalidArgumentException(fmt::format("The given model cannot be null {}", @@ -1786,7 +1816,7 @@ std::vector<std::string> OnnxParser::GetInputs(ModelPtr& model) return inputNames; } -std::vector<std::string> OnnxParser::GetOutputs(ModelPtr& model) +std::vector<std::string> OnnxParserImpl::GetOutputs(ModelPtr& model) { if(model == nullptr) { throw InvalidArgumentException(fmt::format("The given model cannot be null {}", diff --git a/src/armnnOnnxParser/OnnxParser.hpp b/src/armnnOnnxParser/OnnxParser.hpp index a87863e95c..0db93248bc 100644 --- a/src/armnnOnnxParser/OnnxParser.hpp +++ b/src/armnnOnnxParser/OnnxParser.hpp @@ -22,33 +22,34 @@ namespace armnnOnnxParser using ModelPtr = std::unique_ptr<onnx::ModelProto>; -class OnnxParser : public IOnnxParser +class OnnxParserImpl { -using OperationParsingFunction = void(OnnxParser::*)(const onnx::NodeProto& NodeProto); +using OperationParsingFunction = void(OnnxParserImpl::*)(const onnx::NodeProto& NodeProto); public: using GraphPtr = std::unique_ptr<onnx::GraphProto>; /// Create the network from a protobuf binary file on disk - virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override; + armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile); /// Create the network from a protobuf text file on disk - virtual armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile) override; + armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile); /// Create the network directly from protobuf text in a string. Useful for debugging/testing - virtual armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText) override; + armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText); /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name - virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override; + BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const; /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name - virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override; + BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const; public: - OnnxParser(); + OnnxParserImpl(); + ~OnnxParserImpl() = default; static ModelPtr LoadModelFromBinaryFile(const char * fileName); static ModelPtr LoadModelFromTextFile(const char * fileName); diff --git a/src/armnnOnnxParser/test/GetInputsOutputs.cpp b/src/armnnOnnxParser/test/GetInputsOutputs.cpp index b22ef3a308..5bb3095cc7 100644 --- a/src/armnnOnnxParser/test/GetInputsOutputs.cpp +++ b/src/armnnOnnxParser/test/GetInputsOutputs.cpp @@ -68,8 +68,8 @@ struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<ar BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture) { - ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str()); - std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model); + ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); + std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); BOOST_CHECK_EQUAL(1, tensors.size()); BOOST_CHECK_EQUAL("Input", tensors[0]); @@ -77,8 +77,8 @@ BOOST_FIXTURE_TEST_CASE(GetInput, GetInputsOutputsMainFixture) BOOST_FIXTURE_TEST_CASE(GetOutput, GetInputsOutputsMainFixture) { - ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str()); - std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetOutputs(model); + ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); + std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model); BOOST_CHECK_EQUAL(1, tensors.size()); BOOST_CHECK_EQUAL("Output", tensors[0]); } @@ -139,20 +139,20 @@ struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<a BOOST_FIXTURE_TEST_CASE(GetEmptyInputs, GetEmptyInputsOutputsFixture) { - ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str()); - std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model); + ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); + std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); BOOST_CHECK_EQUAL(0, tensors.size()); } BOOST_AUTO_TEST_CASE(GetInputsNullModel) { - BOOST_CHECK_THROW(armnnOnnxParser::OnnxParser::LoadModelFromString(""), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException); } BOOST_AUTO_TEST_CASE(GetOutputsNullModel) { auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf - BOOST_CHECK_THROW(armnnOnnxParser::OnnxParser::LoadModelFromString("nknnk"), armnn::ParseException); + BOOST_CHECK_THROW(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException); } struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> @@ -243,8 +243,8 @@ struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnn BOOST_FIXTURE_TEST_CASE(GetInputsMultipleInputs, GetInputsMultipleFixture) { - ModelPtr model = armnnOnnxParser::OnnxParser::LoadModelFromString(m_Prototext.c_str()); - std::vector<std::string> tensors = armnnOnnxParser::OnnxParser::GetInputs(model); + ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); + std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); BOOST_CHECK_EQUAL(2, tensors.size()); BOOST_CHECK_EQUAL("Input0", tensors[0]); BOOST_CHECK_EQUAL("Input1", tensors[1]); |