From 2b4182fa86c4efe43f3a631266b5185d8a725fa1 Mon Sep 17 00:00:00 2001 From: Cathal Corbett Date: Thu, 18 Nov 2021 10:28:47 +0000 Subject: IVGCVSW-6592 AddTransposeLayer to PyArmNN * AddTransposeLayer to PyArmNN armnn_network.i * AddTranposeDescriptor to PyArmNN armnn_descriptors.i * Add layer to test_network_method_exists() in test_network.py * Add descriptor unit tests to test_descriptors.py Signed-off-by: Cathal Corbett Change-Id: Ic198448ad11d10701b6b263656285bb75d3656cd --- python/pyarmnn/src/pyarmnn/__init__.py | 2 +- .../src/pyarmnn/swig/modules/armnn_descriptors.i | 21 ++++++++++++++++++++- .../src/pyarmnn/swig/modules/armnn_network.i | 16 +++++++++++++++- 3 files changed, 36 insertions(+), 3 deletions(-) (limited to 'python/pyarmnn/src') diff --git a/python/pyarmnn/src/pyarmnn/__init__.py b/python/pyarmnn/src/pyarmnn/__init__.py index 5f3e5ff521..fecc4f5ee0 100644 --- a/python/pyarmnn/src/pyarmnn/__init__.py +++ b/python/pyarmnn/src/pyarmnn/__init__.py @@ -99,7 +99,7 @@ from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \ ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \ StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \ - SplitterDescriptor + TransposeDescriptor, SplitterDescriptor from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation from ._generated.pyarmnn import LstmInputParams, QuantizedLstmInputParams diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i index 8844cea81a..5f4afd399a 100644 --- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i +++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i @@ -1000,7 +1000,6 @@ struct SoftmaxDescriptor bool operator ==(const SoftmaxDescriptor& rhs) const; }; - %feature("docstring", " A descriptor for the TransposeConvolution2d layer. See `INetwork.AddTransposeConvolution2dLayer()`. @@ -1056,6 +1055,26 @@ struct LogicalBinaryDescriptor bool operator ==(const LogicalBinaryDescriptor &rhs) const; }; +%feature("docstring", + " + A descriptor for the Transpose layer. See `INetwork.AddTransposeLayer()`. + + Contains: + m_DimMappings (PermutationVector): Indicates how to translate tensor elements from a given source into the target destination, + when source and target potentially have different memory layouts e.g. {0U, 3U, 1U, 2U}. + + ") TransposeDescriptor; +struct TransposeDescriptor +{ + TransposeDescriptor(); + TransposeDescriptor(const PermutationVector& dimMappings); + + PermutationVector m_DimMappings; + + bool operator ==(const TransposeDescriptor &rhs) const; +}; + + using ConcatDescriptor = OriginsDescriptor; using LogSoftmaxDescriptor = SoftmaxDescriptor; using SplitterDescriptor = ViewsDescriptor; diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i index c978e14c80..f1cae5605b 100644 --- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i +++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i @@ -810,7 +810,7 @@ public: ") AddRankLayer; armnn::IConnectableLayer* AddRankLayer(const char* name = nullptr); - + %feature("docstring", " Adds a Reshape layer to the network. @@ -1006,6 +1006,20 @@ public: ") AddLogicalBinaryLayer; armnn::IConnectableLayer* AddLogicalBinaryLayer(const armnn::LogicalBinaryDescriptor& logicalBinaryDescriptor, const char* name = nullptr); + + %feature("docstring", + " + Adds a Transpose layer to the network. + + Args: + transposeDescriptor (TransposeDescriptor): Description of the transpose layer. + name (str): Optional name for the layer. + + Returns: + IConnectableLayer: Interface for configuring the layer. + ") AddTransposeLayer; + armnn::IConnectableLayer* AddTransposeLayer(const armnn::TransposeDescriptor& transposeDescriptor, + const char* name = nullptr); }; %extend INetwork { -- cgit v1.2.1