aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Network.hpp')
-rw-r--r--src/armnn/Network.hpp176
1 files changed, 78 insertions, 98 deletions
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index cffade5a21..8f16be1684 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -19,246 +19,249 @@
#include "Graph.hpp"
#include "Layer.hpp"
+#include "OptimizedNetworkImpl.hpp"
namespace armnn
{
class Graph;
+using NetworkImplPtr = std::unique_ptr<NetworkImpl, void(*)(NetworkImpl* network)>;
+
/// Private implementation of INetwork.
-class Network final : public INetwork
+class NetworkImpl
{
public:
- Network(NetworkOptions networkOptions = {});
- ~Network();
+ NetworkImpl(NetworkOptions networkOptions = {});
+ ~NetworkImpl();
const Graph& GetGraph() const { return *m_Graph; }
- Status PrintGraph() override;
+ Status PrintGraph();
- IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr) override;
+ IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name=nullptr);
IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddConvolution2dLayer overload is deprecated")
IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
const ConstTensor& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddDepthToSpaceLayer(const DepthToSpaceDescriptor& depthToSpaceDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddDepthwiseConvolution2dLayer(
const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
IConnectableLayer* AddDepthwiseConvolution2dLayer(
const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddDepthwiseConvolution2dLayer overload is deprecated")
IConnectableLayer* AddDepthwiseConvolution2dLayer(
const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
const ConstTensor& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddDequantizeLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddDequantizeLayer(const char* name = nullptr);
IConnectableLayer* AddDetectionPostProcessLayer(
const DetectionPostProcessDescriptor& descriptor,
const ConstTensor& anchors,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddFullyConnectedLayer overload is deprecated")
IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
const ConstTensor& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("This AddGatherLayer overload is deprecated")
- IConnectableLayer* AddGatherLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddGatherLayer(const char* name = nullptr);
IConnectableLayer* AddGatherLayer(const GatherDescriptor& gatherDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr) override;
+ IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr);
IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddConcatLayer instead")
IConnectableLayer* AddMergerLayer(const MergerDescriptor& mergerDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
- IConnectableLayer* AddAbsLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddAbsLayer(const char* name = nullptr);
- IConnectableLayer* AddAdditionLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddAdditionLayer(const char* name = nullptr);
- IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr);
IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
const ConstTensor& mean,
const ConstTensor& variance,
const ConstTensor& beta,
const ConstTensor& gamma,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddRankLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddRankLayer(const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddResizeLayer instead")
IConnectableLayer* AddResizeBilinearLayer(const ResizeBilinearDescriptor& resizeDesc,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddReduceLayer(const ReduceDescriptor& reduceDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr) override;
+ IConnectableLayer* AddConstantLayer(const ConstTensor& input, const char* name = nullptr);
IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddFloorLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddFloorLayer(const char* name = nullptr);
- IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) override;
+ IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr);
IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
const LstmInputParams& params,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddDivisionLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddDivisionLayer(const char* name = nullptr);
- IConnectableLayer* AddSubtractionLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddSubtractionLayer(const char* name = nullptr);
- IConnectableLayer* AddMaximumLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddMaximumLayer(const char* name = nullptr);
- IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr) override;
+ IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr);
- IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr) override;
+ IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, const char* name = nullptr);
- IConnectableLayer* AddQuantizeLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddQuantizeLayer(const char* name = nullptr);
IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddMinimumLayer(const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
- IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddGreaterLayer(const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
- IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddEqualLayer(const char* name = nullptr);
ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead")
- IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddRsqrtLayer(const char* name = nullptr);
- IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddMergeLayer(const char* name = nullptr);
- IConnectableLayer* AddSwitchLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddSwitchLayer(const char* name = nullptr);
- IConnectableLayer* AddPreluLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddPreluLayer(const char* name = nullptr);
IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddStackLayer(const StackDescriptor& stackDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddStandInLayer(const StandInDescriptor& descriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor,
const LstmInputParams& params,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
- const char* name = nullptr) override;
+ const char* name = nullptr);
IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor,
- const char* name = nullptr) override;
+ const char* name = nullptr);
- void Accept(ILayerVisitor& visitor) const override;
+ void Accept(ILayerVisitor& visitor) const;
- void ExecuteStrategy(IStrategy& strategy) const override;
+ void ExecuteStrategy(IStrategy& strategy) const;
private:
IConnectableLayer* AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
@@ -284,29 +287,6 @@ private:
ModelOptions m_ModelOptions;
};
-class OptimizedNetwork final : public IOptimizedNetwork
-{
-public:
- OptimizedNetwork(std::unique_ptr<Graph> graph);
- OptimizedNetwork(std::unique_ptr<Graph> graph, const ModelOptions& modelOptions);
- ~OptimizedNetwork();
-
- Status PrintGraph() override;
- Status SerializeToDot(std::ostream& stream) const override;
-
- profiling::ProfilingGuid GetGuid() const final { return m_Guid; };
-
- Graph& GetGraph() { return *m_Graph; }
- ModelOptions& GetModelOptions() { return m_ModelOptions; }
-
-private:
- std::unique_ptr<Graph> m_Graph;
- profiling::ProfilingGuid m_Guid;
- ModelOptions m_ModelOptions;
-};
-
-
-
struct OptimizationResult
{
bool m_Warning;
@@ -338,7 +318,7 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
bool importEnabled,
Optional<std::vector<std::string>&> errMessages);
-OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr,
+OptimizationResult AssignBackends(OptimizedNetworkImpl* optNetObjPtr,
BackendSettings& backendSettings,
Graph::Iterator& firstLayer,
Graph::Iterator& lastLayer,