aboutsummaryrefslogtreecommitdiff
path: root/python/pyarmnn/test
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2020-08-12 14:59:06 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-08-17 16:10:54 +0000
commit841aca155b35cc17ea9527599d2c364695e28166 (patch)
tree8861d6f275f5955495201f0e1b0677bb44604a17 /python/pyarmnn/test
parent313c99f9a64c7fc51dc70757bffff3088c2e95cf (diff)
downloadarmnn-841aca155b35cc17ea9527599d2c364695e28166.tar.gz
IVGCVSW-5200 Update pyarmnn
* Add HalfPixelCenters to Resize * Update pyarmnn version to semantic versioning * Add fill operator * Add Bf16 optimization * Add Gather operator * Update TransposeConvolution2d descriptor * Add Rank operator * Add load dynamic tensor support of TfLiteParser Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I7e76ed286ab87bd97a65ff62868ba7db7967376f
Diffstat (limited to 'python/pyarmnn/test')
-rw-r--r--python/pyarmnn/test/test_descriptors.py21
-rw-r--r--python/pyarmnn/test/test_network.py19
-rw-r--r--python/pyarmnn/test/test_tflite_parser.py42
3 files changed, 79 insertions, 3 deletions
diff --git a/python/pyarmnn/test/test_descriptors.py b/python/pyarmnn/test/test_descriptors.py
index 6d49747d5a..b0574a14ba 100644
--- a/python/pyarmnn/test/test_descriptors.py
+++ b/python/pyarmnn/test/test_descriptors.py
@@ -143,6 +143,16 @@ def test_fakequantization_descriptor_default_values():
np.allclose(-6, desc.m_Min)
+def test_fill_descriptor_default_values():
+ desc = ann.FillDescriptor()
+ np.allclose(0, desc.m_Value)
+
+
+def test_gather_descriptor_default_values():
+ desc = ann.GatherDescriptor()
+ assert desc.m_Axis == 0
+
+
def test_fully_connected_descriptor_default_values():
desc = ann.FullyConnectedDescriptor()
assert desc.m_BiasEnabled == False
@@ -370,7 +380,7 @@ def test_space_to_batch_nd_descriptor_ctor():
def test_transpose_convolution2d_descriptor_default_values():
- desc = ann.DepthwiseConvolution2dDescriptor()
+ desc = ann.TransposeConvolution2dDescriptor()
assert desc.m_PadLeft == 0
assert desc.m_PadTop == 0
assert desc.m_PadRight == 0
@@ -379,6 +389,7 @@ def test_transpose_convolution2d_descriptor_default_values():
assert desc.m_StrideY == 0
assert desc.m_BiasEnabled == False
assert desc.m_DataLayout == ann.DataLayout_NCHW
+ assert desc.m_OutputShapeEnabled == False
def test_view_descriptor_default_values():
@@ -480,7 +491,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
- 'ElementwiseUnaryDescriptor'])
+ 'ElementwiseUnaryDescriptor',
+ 'FillDescriptor',
+ 'GatherDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
@@ -522,7 +535,9 @@ generated_classes_names = list(map(lambda x: x[0], generated_classes))
'StackDescriptor',
'StridedSliceDescriptor',
'TransposeConvolution2dDescriptor',
- 'ElementwiseUnaryDescriptor'])
+ 'ElementwiseUnaryDescriptor',
+ 'FillDescriptor',
+ 'GatherDescriptor'])
class TestDescriptorMassChecks:
def test_desc_implemented(self, desc_name):
diff --git a/python/pyarmnn/test/test_network.py b/python/pyarmnn/test/test_network.py
index fc2591c1d5..679e640374 100644
--- a/python/pyarmnn/test/test_network.py
+++ b/python/pyarmnn/test/test_network.py
@@ -6,6 +6,23 @@ import stat
import pytest
import pyarmnn as ann
+def test_optimizer_options_default_values():
+ opt = ann.OptimizerOptions()
+ assert opt.m_ReduceFp32ToFp16 == False
+ assert opt.m_Debug == False
+ assert opt.m_ReduceFp32ToBf16 == False
+
+def test_optimizer_options_set_values1():
+ opt = ann.OptimizerOptions(True, True)
+ assert opt.m_ReduceFp32ToFp16 == True
+ assert opt.m_Debug == True
+ assert opt.m_ReduceFp32ToBf16 == False
+
+def test_optimizer_options_set_values2():
+ opt = ann.OptimizerOptions(False, False, True)
+ assert opt.m_ReduceFp32ToFp16 == False
+ assert opt.m_Debug == False
+ assert opt.m_ReduceFp32ToBf16 == True
@pytest.fixture(scope="function")
def get_runtime(shared_data_folder, network_file):
@@ -166,6 +183,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddDivisionLayer',
'AddElementwiseUnaryLayer',
'AddFloorLayer',
+ 'AddFillLayer',
'AddFullyConnectedLayer',
'AddGatherLayer',
'AddInputLayer',
@@ -186,6 +204,7 @@ def test_serialize_to_dot_mode_readonly(network_file, get_runtime, tmpdir):
'AddPreluLayer',
'AddQuantizeLayer',
'AddQuantizedLstmLayer',
+ 'AddRankLayer',
'AddReshapeLayer',
'AddResizeLayer',
'AddSliceLayer',
diff --git a/python/pyarmnn/test/test_tflite_parser.py b/python/pyarmnn/test/test_tflite_parser.py
index 344ec7ca13..8735eef3fb 100644
--- a/python/pyarmnn/test/test_tflite_parser.py
+++ b/python/pyarmnn/test/test_tflite_parser.py
@@ -7,6 +7,12 @@ import pyarmnn as ann
import numpy as np
+def test_TfLiteParserOptions_default_values():
+ parserOptions = ann.TfLiteParserOptions()
+ assert parserOptions.m_InferAndValidate == False
+ assert parserOptions.m_StandInLayerForUnsupported == False
+
+
@pytest.fixture()
def parser(shared_data_folder):
"""
@@ -30,6 +36,42 @@ def test_check_tflite_parser_swig_ownership(parser):
assert parser.thisown
+def test_tflite_parser_with_optional_options():
+ parserOptions = ann.TfLiteParserOptions()
+ parserOptions.m_InferAndValidate = True
+ parser = ann.ITfLiteParser(parserOptions)
+ assert parser.thisown
+
+
+def create_with_opt() :
+ parserOptions = ann.TfLiteParserOptions()
+ parserOptions.m_InferAndValidate = True
+ return ann.ITfLiteParser(parserOptions)
+
+def test_tflite_parser_with_optional_options_out_of_scope(shared_data_folder):
+ parser = create_with_opt()
+ network = parser.CreateNetworkFromBinaryFile(os.path.join(shared_data_folder, "mock_model.tflite"))
+
+ graphs_count = parser.GetSubgraphCount()
+ graph_id = graphs_count - 1
+
+ input_names = parser.GetSubgraphInputTensorNames(graph_id)
+ input_binding_info = parser.GetNetworkInputBindingInfo(graph_id, input_names[0])
+
+ output_names = parser.GetSubgraphOutputTensorNames(graph_id)
+
+ preferred_backends = [ann.BackendId('CpuAcc'), ann.BackendId('CpuRef')]
+
+ options = ann.CreationOptions()
+ runtime = ann.IRuntime(options)
+
+ opt_network, messages = ann.Optimize(network, preferred_backends, runtime.GetDeviceSpec(), ann.OptimizerOptions())
+ assert 0 == len(messages)
+
+ net_id, messages = runtime.LoadNetwork(opt_network)
+ assert "" == messages
+
+
def test_tflite_get_sub_graph_count(parser):
graphs_count = parser.GetSubgraphCount()
assert graphs_count == 1