aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test
diff options
context:
space:
mode:
authorRyan OShea <ryan.oshea3@arm.com>2022-03-09 02:07:24 +0000
committerRyan OShea <ryan.oshea3@arm.com>2022-03-09 02:07:24 +0000
commit89655004ba20d36ec4882ed9c10f5d91aa244af2 (patch)
treeabe29ee911b97bf1256af1832b9e00bb45c5c267 /python/pyarmnn/test
parent3464ba127b83cd36d65cdc7ee9f5dd7b3715a18e (diff)
downloadarmnn-89655004ba20d36ec4882ed9c10f5d91aa244af2.tar.gz
IVGCVSW-6749 Add Pooling3d to PyArmnn
* Add layer to __init__.py * Add descriptor for Pooling3d * Add descriptor test for Pooling3d * Add network test for Pooling3d layer Signed-off-by: Ryan OShea <ryan.oshea3@arm.com> Change-Id: Id5e1587a89d3ffb5bee7764a92b299fa43a2ae35
Diffstat (limited to 'python/pyarmnn/test')
-rw-r--r--python/pyarmnn/test/test_descriptors.py18
-rw-r--r--python/pyarmnn/test/test_network.py1
2 files changed, 19 insertions, 0 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 262b8fcf2a..8969344d6d 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -282,6 +282,24 @@ def test_pooling_descriptor_default_values():
assert desc.m_PaddingMethod == ann.PaddingMethod_Exclude
assert desc.m_DataLayout == ann.DataLayout_NCHW
+def test_pooling_3d_descriptor_default_values():
+ desc = ann.Pooling3dDescriptor()
+ assert desc.m_PoolType == ann.PoolingAlgorithm_Max
+ assert desc.m_PadLeft == 0
+ assert desc.m_PadTop == 0
+ assert desc.m_PadRight == 0
+ assert desc.m_PadBottom == 0
+ assert desc.m_PadFront == 0
+ assert desc.m_PadBack == 0
+ assert desc.m_PoolHeight == 0
+ assert desc.m_PoolWidth == 0
+ assert desc.m_StrideX == 0
+ assert desc.m_StrideY == 0
+ assert desc.m_StrideZ == 0
+ assert desc.m_OutputShapeRounding == ann.OutputShapeRounding_Floor
+ assert desc.m_PaddingMethod == ann.PaddingMethod_Exclude
+ assert desc.m_DataLayout == ann.DataLayout_NCDHW
+
def test_reshape_descriptor_default_values():
desc = ann.ReshapeDescriptor()
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index 27ad70be3b..8cb81221e2 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -225,6 +225,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddPadLayer',
'AddPermuteLayer',
'AddPooling2dLayer',
+ 'AddPooling3dLayer',
'AddPreluLayer',
'AddQuantizeLayer',
'AddQuantizedLstmLayer',