aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorCathal Corbett <cathal.corbett@arm.com>2021-11-18 10:28:47 +0000
committerCathal Corbett <cathal.corbett@arm.com>2021-11-23 09:22:21 +0000
commit2b4182fa86c4efe43f3a631266b5185d8a725fa1 (patch)
treec0db9b4f5b0eefd4426d4f7f3bc1d22d84f33b89 /python
parentf0836e0d7efa947f7589c476b531944078cc02d2 (diff)
downloadarmnn-2b4182fa86c4efe43f3a631266b5185d8a725fa1.tar.gz
IVGCVSW-6592 AddTransposeLayer to PyArmNN
* AddTransposeLayer to PyArmNN armnn_network.i * AddTranposeDescriptor to PyArmNN armnn_descriptors.i * Add layer to test_network_method_exists() in test_network.py * Add descriptor unit tests to test_descriptors.py Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: Ic198448ad11d10701b6b263656285bb75d3656cd
Diffstat (limited to 'python')
-rw-r--r--python/pyarmnn/src/pyarmnn/__init__.py2
-rw-r--r--python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i21
-rw-r--r--python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i16
-rw-r--r--python/pyarmnn/test/test_descriptors.py11
-rw-r--r--python/pyarmnn/test/test_network.py3
5 files changed, 49 insertions, 4 deletions
diff --git a/python/pyarmnn/src/pyarmnn/__init__.py b/python/pyarmnn/src/pyarmnn/__init__.py
index 5f3e5ff521..fecc4f5ee0 100644
--- a/python/pyarmnn/src/pyarmnn/__init__.py
+++ b/python/pyarmnn/src/pyarmnn/__init__.py
@@ -99,7 +99,7 @@ from ._generated.pyarmnn import OutputShapeRounding_Ceiling, OutputShapeRounding
from ._generated.pyarmnn import ResizeMethod_Bilinear, ResizeMethod_NearestNeighbor, ResizeDescriptor, \
ReshapeDescriptor, SliceDescriptor, SpaceToBatchNdDescriptor, SpaceToDepthDescriptor, StandInDescriptor, \
StackDescriptor, StridedSliceDescriptor, SoftmaxDescriptor, TransposeConvolution2dDescriptor, \
- SplitterDescriptor
+ TransposeDescriptor, SplitterDescriptor
from ._generated.pyarmnn import ConcatDescriptor, CreateDescriptorForConcatenation
from ._generated.pyarmnn import LstmInputParams, QuantizedLstmInputParams
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
index 8844cea81a..5f4afd399a 100644
--- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
@@ -1000,7 +1000,6 @@ struct SoftmaxDescriptor
bool operator ==(const SoftmaxDescriptor& rhs) const;
};
-
%feature("docstring",
"
A descriptor for the TransposeConvolution2d layer. See `INetwork.AddTransposeConvolution2dLayer()`.
@@ -1056,6 +1055,26 @@ struct LogicalBinaryDescriptor
bool operator ==(const LogicalBinaryDescriptor &rhs) const;
};
+%feature("docstring",
+ "
+ A descriptor for the Transpose layer. See `INetwork.AddTransposeLayer()`.
+
+ Contains:
+ m_DimMappings (PermutationVector): Indicates how to translate tensor elements from a given source into the target destination,
+ when source and target potentially have different memory layouts e.g. {0U, 3U, 1U, 2U}.
+
+ ") TransposeDescriptor;
+struct TransposeDescriptor
+{
+ TransposeDescriptor();
+ TransposeDescriptor(const PermutationVector& dimMappings);
+
+ PermutationVector m_DimMappings;
+
+ bool operator ==(const TransposeDescriptor &rhs) const;
+};
+
+
using ConcatDescriptor = OriginsDescriptor;
using LogSoftmaxDescriptor = SoftmaxDescriptor;
using SplitterDescriptor = ViewsDescriptor;
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
index c978e14c80..f1cae5605b 100644
--- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
@@ -810,7 +810,7 @@ public:
") AddRankLayer;
armnn::IConnectableLayer* AddRankLayer(const char* name = nullptr);
-
+
%feature("docstring",
"
Adds a Reshape layer to the network.
@@ -1006,6 +1006,20 @@ public:
") AddLogicalBinaryLayer;
armnn::IConnectableLayer* AddLogicalBinaryLayer(const armnn::LogicalBinaryDescriptor& logicalBinaryDescriptor,
const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a Transpose layer to the network.
+
+ Args:
+ transposeDescriptor (TransposeDescriptor): Description of the transpose layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddTransposeLayer;
+ armnn::IConnectableLayer* AddTransposeLayer(const armnn::TransposeDescriptor& transposeDescriptor,
+ const char* name = nullptr);
};
%extend INetwork {
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 0360196614..a39766696f 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -391,6 +391,15 @@ def test_transpose_convolution2d_descriptor_default_values():
assert desc.m_DataLayout == ann.DataLayout_NCHW
assert desc.m_OutputShapeEnabled == False
+def test_transpose_descriptor_default_values():
+ pv = ann.PermutationVector((0, 3, 2, 1, 4))
+ desc = ann.TransposeDescriptor(pv)
+ assert desc.m_DimMappings.GetSize() == 5
+ assert desc.m_DimMappings[0] == 0
+ assert desc.m_DimMappings[1] == 3
+ assert desc.m_DimMappings[2] == 2
+ assert desc.m_DimMappings[3] == 1
+ assert desc.m_DimMappings[4] == 4
def test_view_descriptor_default_values():
desc = ann.SplitterDescriptor()
@@ -495,6 +504,7 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
+ 'TransposeDescriptor',
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
@@ -540,6 +550,7 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
+ 'TransposeDescriptor',
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index 1792041e11..ff1c66edc4 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -238,7 +238,8 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddStridedSliceLayer',
'AddSubtractionLayer',
'AddSwitchLayer',
- 'AddTransposeConvolution2dLayer'
+ 'AddTransposeConvolution2dLayer',
+ 'AddTransposeLayer'
])
def test_network_method_exists(method):
assert getattr(ann.INetwork, method, None)