aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test/test_descriptors.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyarmnn/test/test_descriptors.py')
-rw-r--r--python/pyarmnn/test/test_descriptors.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 6d49747d5a..b0574a14ba 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -143,6 +143,16 @@ def test_fakequantization_descriptor_default_values():
np.allclose(-6, desc.m_Min)
+def test_fill_descriptor_default_values():
+ desc = ann.FillDescriptor()
+ np.allclose(0, desc.m_Value)
+
+
+def test_gather_descriptor_default_values():
+ desc = ann.GatherDescriptor()
+ assert desc.m_Axis == 0
+
+
def test_fully_connected_descriptor_default_values():
desc = ann.FullyConnectedDescriptor()
assert desc.m_BiasEnabled == False
@@ -370,7 +380,7 @@ def test_space_to_batch_nd_descriptor_ctor():
def test_transpose_convolution2d_descriptor_default_values():
- desc = ann.DepthwiseConvolution2dDescriptor()
+ desc = ann.TransposeConvolution2dDescriptor()
assert desc.m_PadLeft == 0
assert desc.m_PadTop == 0
assert desc.m_PadRight == 0
@@ -379,6 +389,7 @@ def test_transpose_convolution2d_descriptor_default_values():
assert desc.m_StrideY == 0
assert desc.m_BiasEnabled == False
assert desc.m_DataLayout == ann.DataLayout_NCHW
+ assert desc.m_OutputShapeEnabled == False
def test_view_descriptor_default_values():
@@ -480,7 +491,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
- 'ElementwiseUnaryDescriptor'])
+ 'ElementwiseUnaryDescriptor',
+ 'FillDescriptor',
+ 'GatherDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -522,7 +535,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
- 'ElementwiseUnaryDescriptor'])
+ 'ElementwiseUnaryDescriptor',
+ 'FillDescriptor',
+ 'GatherDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):