diff options
Diffstat (limited to 'src/armnn/Network.hpp')
-rw-r--r-- | src/armnn/Network.hpp | 176 |
1 files changed, 78 insertions, 98 deletions
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp index cffade5a21..8f16be1684 100644 --- a/src/armnn/Network.hpp +++ b/src/armnn/Network.hpp @@ -19,246 +19,249 @@ #include "Graph.hpp" #include "Layer.hpp" +#include "OptimizedNetworkImpl.hpp" namespace armnn { class Graph; +using NetworkImplPtr = std::unique_ptr<NetworkImpl, void(*)(NetworkImpl* network)>; + /// Private implementation of INetwork. -class Network final : public INetwork +class NetworkImpl { public: - Network(NetworkOptions networkOptions = {}); - ~Network(); + NetworkImpl(NetworkOptions networkOptions = {}); + ~NetworkImpl(); const Graph& GetGraph() const { return *m_Graph; } - Status PrintGraph() override; + Status PrintGraph(); - IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override; + IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr); IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, const Optional<ConstTensor>& biases, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated") IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated") IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, const ConstTensor& biases, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddDepthToSpaceLayer(const DepthToSpaceDescriptor& depthToSpaceDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddDepthwiseConvolution2dLayer( const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, const Optional<ConstTensor>& biases, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated") IConnectableLayer* AddDepthwiseConvolution2dLayer( const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated") IConnectableLayer* AddDepthwiseConvolution2dLayer( const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, const ConstTensor& biases, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override; + IConnectableLayer* AddDequantizeLayer(const char* name = nullptr); IConnectableLayer* AddDetectionPostProcessLayer( const DetectionPostProcessDescriptor& descriptor, const ConstTensor& anchors, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor, const ConstTensor& weights, const Optional<ConstTensor>& biases, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated") IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor, const ConstTensor& weights, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated") IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor, const ConstTensor& weights, const ConstTensor& biases, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("This AddGatherLayer overload is deprecated") - IConnectableLayer* AddGatherLayer(const char* name = nullptr) override; + IConnectableLayer* AddGatherLayer(const char* name = nullptr); IConnectableLayer* AddGatherLayer(const GatherDescriptor& gatherDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr); IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead") IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead") - IConnectableLayer* AddAbsLayer(const char* name = nullptr) override; + IConnectableLayer* AddAbsLayer(const char* name = nullptr); - IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override; + IConnectableLayer* AddAdditionLayer(const char* name = nullptr); - IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override; + IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr); IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc, const ConstTensor& mean, const ConstTensor& variance, const ConstTensor& beta, const ConstTensor& gamma, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddRankLayer(const char* name = nullptr) override; + IConnectableLayer* AddRankLayer(const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead") IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddReduceLayer(const ReduceDescriptor& reduceDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override; + IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr); IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddFloorLayer(const char* name = nullptr) override; + IConnectableLayer* AddFloorLayer(const char* name = nullptr); - IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override; + IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr); IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor, const LstmInputParams& params, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override; + IConnectableLayer* AddDivisionLayer(const char* name = nullptr); - IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override; + IConnectableLayer* AddSubtractionLayer(const char* name = nullptr); - IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override; + IConnectableLayer* AddMaximumLayer(const char* name = nullptr); - IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr); - IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override; + IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr); - IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override; + IConnectableLayer* AddQuantizeLayer(const char* name = nullptr); IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); - IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override; + IConnectableLayer* AddMinimumLayer(const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead") - IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override; + IConnectableLayer* AddGreaterLayer(const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead") - IConnectableLayer* AddEqualLayer(const char* name = nullptr) override; + IConnectableLayer* AddEqualLayer(const char* name = nullptr); ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead") - IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override; + IConnectableLayer* AddRsqrtLayer(const char* name = nullptr); - IConnectableLayer* AddMergeLayer(const char* name = nullptr) override; + IConnectableLayer* AddMergeLayer(const char* name = nullptr); - IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override; + IConnectableLayer* AddSwitchLayer(const char* name = nullptr); - IConnectableLayer* AddPreluLayer(const char* name = nullptr) override; + IConnectableLayer* AddPreluLayer(const char* name = nullptr); IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor, const ConstTensor& weights, const Optional<ConstTensor>& biases, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddStandInLayer(const StandInDescriptor& descriptor, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor, const LstmInputParams& params, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params, - const char* name = nullptr) override; + const char* name = nullptr); IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor, - const char* name = nullptr) override; + const char* name = nullptr); - void Accept(ILayerVisitor& visitor) const override; + void Accept(ILayerVisitor& visitor) const; - void ExecuteStrategy(IStrategy& strategy) const override; + void ExecuteStrategy(IStrategy& strategy) const; private: IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor, @@ -284,29 +287,6 @@ private: ModelOptions m_ModelOptions; }; -class OptimizedNetwork final : public IOptimizedNetwork -{ -public: - OptimizedNetwork(std::unique_ptr<Graph> graph); - OptimizedNetwork(std::unique_ptr<Graph> graph, const ModelOptions& modelOptions); - ~OptimizedNetwork(); - - Status PrintGraph() override; - Status SerializeToDot(std::ostream& stream) const override; - - profiling::ProfilingGuid GetGuid() const final { return m_Guid; }; - - Graph& GetGraph() { return *m_Graph; } - ModelOptions& GetModelOptions() { return m_ModelOptions; } - -private: - std::unique_ptr<Graph> m_Graph; - profiling::ProfilingGuid m_Guid; - ModelOptions m_ModelOptions; -}; - - - struct OptimizationResult { bool m_Warning; @@ -338,7 +318,7 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph, bool importEnabled, Optional<std::vector<std::string>&> errMessages); -OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, +OptimizationResult AssignBackends(OptimizedNetworkImpl* optNetObjPtr, BackendSettings& backendSettings, Graph::Iterator& firstLayer, Graph::Iterator& lastLayer, |