# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Helpers for running TfLite models through an interpreter and checking outputs."""
import os

import numpy as np
import tflite_runtime.interpreter as tflite


def run_mock_model(delegate, test_data_folder):
    """Run ``mock_model.tflite`` once with the given delegate on random input.

    Args:
        delegate: Delegate object handed to the interpreter as its only
            entry in ``experimental_delegates``.
        test_data_folder: Directory that contains ``mock_model.tflite``.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    # Only the first input tensor is fed.
    first_input = interpreter.get_input_details()[0]
    input_shape = first_input['shape']

    # Draw genuinely random uint8 values. Casting random_sample() output
    # (floats in [0, 1)) to uint8 truncates every element to 0, which made
    # the "random" input a constant all-zero tensor.
    input_data = np.random.randint(0, 256, size=tuple(input_shape),
                                   dtype=np.uint8)
    interpreter.set_tensor(first_input['index'], input_data)

    interpreter.invoke()


def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Run one inference on a model file and return all of its outputs.

    Args:
        test_data_folder: Directory that contains the model file.
        model_filename: File name of the model inside ``test_data_folder``.
        inputs: Sequence of input values; ``inputs[i]`` is written to the
            i-th input tensor reported by the interpreter.
        delegates: Optional value passed straight through as the
            interpreter's ``experimental_delegates`` argument.

    Returns:
        A list with one entry per output tensor, in the order reported by
        the interpreter's output details.
    """
    model_path = os.path.join(test_data_folder, model_filename)
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=delegates)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Index input_details explicitly (rather than zip) so that passing more
    # inputs than the model has input tensors still raises IndexError.
    for i, input_value in enumerate(inputs):
        interpreter.set_tensor(input_details[i]['index'], input_value)

    interpreter.invoke()

    return [interpreter.get_tensor(output['index'])
            for output in output_details]


def compare_outputs(outputs, expected_outputs):
    """Assert that each output matches its expected counterpart.

    Checks, in order: number of outputs, then per output the shape, the
    dtype and the element values. Values are compared exactly
    (element-wise equality), with no floating-point tolerance.

    Args:
        outputs: Sequence of arrays produced by the model.
        expected_outputs: Sequence of reference arrays, same order.

    Raises:
        AssertionError: On the first mismatch found.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'

    for i, (actual, expected) in enumerate(zip(outputs, expected_outputs)):
        assert actual.shape == expected.shape, \
            'Incorrect output shape on output#{}'.format(i)
        assert actual.dtype == expected.dtype, \
            'Incorrect output data type on output#{}'.format(i)
        # Compare the actual elements. The previous check,
        # `a.all() == b.all()`, only compared two booleans ("are all
        # elements non-zero?") and accepted almost any pair of arrays.
        assert np.array_equal(actual, expected), \
            'Incorrect output value on output#{}'.format(i)