diff options
Diffstat (limited to 'python/pyarmnn/test/test_descriptors.py')
-rw-r--r-- | python/pyarmnn/test/test_descriptors.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py index 6d49747d5a..b0574a14ba 100644 --- a/python/pyarmnn/test/test_descriptors.py +++ b/python/pyarmnn/test/test_descriptors.py @@ -143,6 +143,16 @@ def test_fakequantization_descriptor_default_values(): np.allclose(-6, desc.m_Min) +def test_fill_descriptor_default_values(): + desc = ann.FillDescriptor() + np.allclose(0, desc.m_Value) + + +def test_gather_descriptor_default_values(): + desc = ann.GatherDescriptor() + assert desc.m_Axis == 0 + + def test_fully_connected_descriptor_default_values(): desc = ann.FullyConnectedDescriptor() assert desc.m_BiasEnabled == False @@ -370,7 +380,7 @@ def test_space_to_batch_nd_descriptor_ctor(): def test_transpose_convolution2d_descriptor_default_values(): - desc = ann.DepthwiseConvolution2dDescriptor() + desc = ann.TransposeConvolution2dDescriptor() assert desc.m_PadLeft == 0 assert desc.m_PadTop == 0 assert desc.m_PadRight == 0 @@ -379,6 +389,7 @@ def test_transpose_convolution2d_descriptor_default_values(): assert desc.m_StrideY == 0 assert desc.m_BiasEnabled == False assert desc.m_DataLayout == ann.DataLayout_NCHW + assert desc.m_OutputShapeEnabled == False def test_view_descriptor_default_values(): @@ -480,7 +491,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'StackDescriptor', 'StridedSliceDescriptor', 'TransposeConvolution2dDescriptor', - 'ElementwiseUnaryDescriptor']) + 'ElementwiseUnaryDescriptor', + 'FillDescriptor', + 'GatherDescriptor']) class TestDescriptorMassChecks: def test_desc_implemented(self, desc_name): @@ -522,7 +535,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'StackDescriptor', 'StridedSliceDescriptor', 'TransposeConvolution2dDescriptor', - 'ElementwiseUnaryDescriptor']) + 'ElementwiseUnaryDescriptor', + 'FillDescriptor', + 'GatherDescriptor']) class TestDescriptorMassChecks: def test_desc_implemented(self, desc_name): |