aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-11-18 15:24:50 +0000
committerTeresa Charlin <teresa.charlinreyes@arm.com>2021-11-23 14:44:59 +0000
commitf7b5011298367c6d635b17ad029c627076072198 (patch)
treec11f1f5e9a84128da71dcd1e76f855fb4746aabf /python/pyarmnn/test
parent181473307ca4e8745dbb0c8474f2615d1752b03a (diff)
downloadarmnn-f7b5011298367c6d635b17ad029c627076072198.tar.gz
IVGCVSW-6585 AddChannelShuffleLayer to PyArmNN
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I2bfc54ea9aae78c60a66d7a5c39a33ca8a238e62
Diffstat (limited to 'python/pyarmnn/test')
-rw-r--r--python/pyarmnn/test/test_descriptors.py11
-rw-r--r--python/pyarmnn/test/test_network.py1
2 files changed, 10 insertions, 2 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 54b79d7397..80c5359eb2 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -88,6 +88,11 @@ def test_batchtospacend_descriptor_ctor():
assert [(4, 5), (6, 7)] == desc.m_Crops
+def test_channelshuffle_descriptor_default_values():
+ desc = ann.ChannelShuffleDescriptor()
+ assert desc.m_Axis == 0
+ assert desc.m_NumGroups == 0
+
def test_convolution2d_descriptor_default_values():
desc = ann.Convolution2dDescriptor()
assert desc.m_PadLeft == 0
@@ -527,7 +532,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -574,7 +580,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index 04e1b7a05f..e33470e090 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -193,6 +193,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddBatchNormalizationLayer',
'AddBatchToSpaceNdLayer',
'AddCastLayer',
+ 'AddChannelShuffleLayer',
'AddComparisonLayer',
'AddConcatLayer',
'AddConstantLayer',