aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-11-18 15:24:50 +0000
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-11-23 14:44:59 +0000
commitf7b5011298367c6d635b17ad029c627076072198 (patch)
treec11f1f5e9a84128da71dcd1e76f855fb4746aabf /python
parent181473307ca4e8745dbb0c8474f2615d1752b03a (diff)
downloadarmnn-f7b5011298367c6d635b17ad029c627076072198.tar.gz
IVGCVSW-6585 AddChannelShuffleLayer to PyArmNN
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I2bfc54ea9aae78c60a66d7a5c39a33ca8a238e62
Diffstat (limited to 'python')
-rw-r--r--python/pyarmnn/src/pyarmnn/__init__.py4
-rw-r--r--python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i20
-rw-r--r--python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i17
-rw-r--r--python/pyarmnn/test/test_descriptors.py11
-rw-r--r--python/pyarmnn/test/test_network.py1
5 files changed, 48 insertions, 5 deletions
diff --git a/python/pyarmnn/src/pyarmnn/__init__.py b/python/pyarmnn/src/pyarmnn/__init__.py
index 1a71844cfc..b71fc3632b 100644
--- a/python/pyarmnn/src/pyarmnn/__init__.py
+++ b/python/pyarmnn/src/pyarmnn/__init__.py
@@ -79,8 +79,8 @@ from ._generated.pyarmnn import ActivationFunction_Abs, ActivationFunction_Bound
ActivationFunction_Sqrt, ActivationFunction_Square, ActivationFunction_TanH, ActivationDescriptor
from ._generated.pyarmnn import ArgMinMaxFunction_Max, ArgMinMaxFunction_Min, ArgMinMaxDescriptor
from ._generated.pyarmnn import BatchNormalizationDescriptor, BatchToSpaceNdDescriptor
-from ._generated.pyarmnn import ComparisonDescriptor, ComparisonOperation_Equal, ComparisonOperation_Greater, \
- ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
+from ._generated.pyarmnn import ChannelShuffleDescriptor, ComparisonDescriptor, ComparisonOperation_Equal, \
+ ComparisonOperation_Greater, ComparisonOperation_GreaterOrEqual, ComparisonOperation_Less, \
ComparisonOperation_LessOrEqual, ComparisonOperation_NotEqual
from ._generated.pyarmnn import UnaryOperation_Abs, UnaryOperation_Exp, UnaryOperation_Sqrt, UnaryOperation_Rsqrt, \
UnaryOperation_Neg, ElementwiseUnaryDescriptor
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
index e51c8674ab..d20796a88e 100644
--- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_descriptors.i
@@ -133,6 +133,26 @@ struct BatchToSpaceNdDescriptor
%feature("docstring",
"
+ A descriptor for the ChannelShuffle layer. See `INetwork.AddChannelShuffleLayer()`.
+
+ Contains:
+ m_NumGroups (int): Underlying C++ type is uint32_t. Number of groups for the shuffle operation. Default: 0.
+ m_Axis (int): Underlying C++ type is uint32_t. 0-based axis along which shuffle is performed. Default: 0.
+
+ ") ChannelShuffleDescriptor;
+struct ChannelShuffleDescriptor
+{
+ ChannelShuffleDescriptor();
+ ChannelShuffleDescriptor(int numGroups, int axis);
+
+ int m_NumGroups;
+ int m_Axis;
+
+ bool operator ==(const ChannelShuffleDescriptor &rhs) const;
+};
+
+%feature("docstring",
+ "
A descriptor for the Comparison layer. See `INetwork.AddComparisonLayer()`.
Contains:
diff --git a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
index fe626dc2ea..789e428e46 100644
--- a/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
+++ b/python/pyarmnn/src/pyarmnn/swig/modules/armnn_network.i
@@ -425,7 +425,22 @@ public:
IConnectableLayer: Interface for configuring the layer.
") AddBatchToSpaceNdLayer;
armnn::IConnectableLayer* AddBatchToSpaceNdLayer(const armnn::BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
- const char* name = nullptr);
+ const char* name = nullptr);
+
+ %feature("docstring",
+ "
+ Adds a ChannelShuffle layer to the network.
+
+ Args:
+ channelShuffleDescriptor (ChannelShuffleDescriptor): Configuration parameters for the layer.
+ name (str): Optional name for the layer.
+
+ Returns:
+ IConnectableLayer: Interface for configuring the layer.
+ ") AddChannelShuffleLayer;
+ armnn::IConnectableLayer* AddChannelShuffleLayer(const armnn::ChannelShuffleDescriptor& channelShuffleDescriptor,
+ const char* name = nullptr);
+
%feature("docstring",
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 54b79d7397..80c5359eb2 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -88,6 +88,11 @@ def test_batchtospacend_descriptor_ctor():
assert [(4, 5), (6, 7)] == desc.m_Crops
+def test_channelshuffle_descriptor_default_values():
+ desc = ann.ChannelShuffleDescriptor()
+ assert desc.m_Axis == 0
+ assert desc.m_NumGroups == 0
+
def test_convolution2d_descriptor_default_values():
desc = ann.Convolution2dDescriptor()
assert desc.m_PadLeft == 0
@@ -527,7 +532,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -574,7 +580,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index 04e1b7a05f..e33470e090 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -193,6 +193,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddBatchNormalizationLayer',
'AddBatchToSpaceNdLayer',
'AddCastLayer',
+ 'AddChannelShuffleLayer',
'AddComparisonLayer',
'AddConcatLayer',
'AddConstantLayer',