diff options
Diffstat (limited to 'python/pyarmnn/test')
-rw-r--r-- | python/pyarmnn/test/test_descriptors.py | 11 | ||||
-rw-r--r-- | python/pyarmnn/test/test_network.py | 3 |
2 files changed, 13 insertions, 1 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py index 0360196614..a39766696f 100644 --- a/python/pyarmnn/test/test_descriptors.py +++ b/python/pyarmnn/test/test_descriptors.py @@ -391,6 +391,15 @@ def test_transpose_convolution2d_descriptor_default_values(): assert desc.m_DataLayout == ann.DataLayout_NCHW assert desc.m_OutputShapeEnabled == False +def test_transpose_descriptor_default_values(): + pv = ann.PermutationVector((0, 3, 2, 1, 4)) + desc = ann.TransposeDescriptor(pv) + assert desc.m_DimMappings.GetSize() == 5 + assert desc.m_DimMappings[0] == 0 + assert desc.m_DimMappings[1] == 3 + assert desc.m_DimMappings[2] == 2 + assert desc.m_DimMappings[3] == 1 + assert desc.m_DimMappings[4] == 4 def test_view_descriptor_default_values(): desc = ann.SplitterDescriptor() @@ -495,6 +504,7 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'StackDescriptor', 'StridedSliceDescriptor', 'TransposeConvolution2dDescriptor', + 'TransposeDescriptor', 'ElementwiseUnaryDescriptor', 'FillDescriptor', 'GatherDescriptor', @@ -540,6 +550,7 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes)) 'StackDescriptor', 'StridedSliceDescriptor', 'TransposeConvolution2dDescriptor', + 'TransposeDescriptor', 'ElementwiseUnaryDescriptor', 'FillDescriptor', 'GatherDescriptor', diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py index 1792041e11..ff1c66edc4 100644 --- a/python/pyarmnn/test/test_network.py +++ b/python/pyarmnn/test/test_network.py @@ -238,7 +238,8 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir): 'AddStridedSliceLayer', 'AddSubtractionLayer', 'AddSwitchLayer', - 'AddTransposeConvolution2dLayer' + 'AddTransposeConvolution2dLayer', + 'AddTransposeLayer' ]) def test_network_method_exists(method): assert getattr(ann.INetwork, method, None) |