aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test/test_descriptors.py')
-rw-r--r--python/pyarmnn/test/test_descriptors.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 54b79d7397..80c5359eb2 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -88,6 +88,11 @@ def test_batchtospacend_descriptor_ctor():
assert [(4, 5), (6, 7)] == desc.m_Crops
+def test_channelshuffle_descriptor_default_values():
+ desc = ann.ChannelShuffleDescriptor()
+ assert desc.m_Axis == 0
+ assert desc.m_NumGroups == 0
+
def test_convolution2d_descriptor_default_values():
desc = ann.Convolution2dDescriptor()
assert desc.m_PadLeft == 0
@@ -527,7 +532,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -574,7 +580,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'ElementwiseUnaryDescriptor',
'FillDescriptor',
'GatherDescriptor',
- 'LogicalBinaryDescriptor'])
+ 'LogicalBinaryDescriptor',
+ 'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):