aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test')
-rw-r--r--python/pyarmnn/test/test_descriptors.py11
-rw-r--r--python/pyarmnn/test/test_network.py1
2 files changed, 10 insertions, 2 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 54b79d7397..80c5359eb2 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -88,6 +88,11 @@ def test_batchtospacend_descriptor_ctor():
assert [(4, 5), (6, 7)] == desc.m_Crops
+def test_channelshuffle_descriptor_default_values():
+ desc = ann.ChannelShuffleDescriptor()
+ assert desc.m_Axis == 0
+ assert desc.m_NumGroups == 0
+
def test_convolution2d_descriptor_default_values():
desc = ann.Convolution2dDescriptor()
assert desc.m_PadLeft == 0
@@ -527,7 +532,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -574,7 +580,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index 04e1b7a05f..e33470e090 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -193,6 +193,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddBatchNormalizationLayer',
'AddBatchToSpaceNdLayer',
'AddCastLayer',
+ 'AddChannelShuffleLayer',
'AddComparisonLayer',
'AddConcatLayer',
'AddConstantLayer',