diff options
Diffstat (limited to 'python/pyarmnn/test/test_descriptors.py')
-rw-r--r-- | python/pyarmnn/test/test_descriptors.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py index 54b79d7397..80c5359eb2 100644 --- a/python/pyarmnn/test/test_descriptors.py +++ b/python/pyarmnn/test/test_descriptors.py @@ -88,6 +88,11 @@ def test_batchtospacend_descriptor_ctor(): assert [(4, 5), (6, 7)] == desc.m_Crops +def test_channelshuffle_descriptor_default_values(): + desc = ann.ChannelShuffleDescriptor() + assert desc.m_Axis == 0 + assert desc.m_NumGroups == 0 + def test_convolution2d_descriptor_default_values(): desc = ann.Convolution2dDescriptor() assert desc.m_PadLeft == 0 @@ -527,7 +532,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'ElementwiseUnaryDescriptor', 'FillDescriptor', 'GatherDescriptor', - 'LogicalBinaryDescriptor']) + 'LogicalBinaryDescriptor', + 'ChannelShuffleDescriptor']) class TestDescriptorMassChecks: def test_desc_implemented(self, desc_name): @@ -574,7 +580,8 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'ElementwiseUnaryDescriptor', 'FillDescriptor', 'GatherDescriptor', - 'LogicalBinaryDescriptor']) + 'LogicalBinaryDescriptor', + 'ChannelShuffleDescriptor']) class TestDescriptorMassChecks: def test_desc_implemented(self, desc_name): |