diff options
Diffstat (limited to 'delegate/python/test/utils.py')
-rw-r--r-- | delegate/python/test/utils.py | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/delegate/python/test/utils.py b/delegate/python/test/utils.py index 3adc24fe35..f3761ec8a1 100644 --- a/delegate/python/test/utils.py +++ b/delegate/python/test/utils.py @@ -21,4 +21,33 @@ def run_mock_model(delegate, test_data_folder): input_data = np.array(np.random.random_sample(input_shape), dtype=np.uint8) interpreter.set_tensor(input_details[0]['index'], input_data) - interpreter.invoke()
\ No newline at end of file + interpreter.invoke() + +def run_inference(test_data_folder, model_filename, inputs, delegates=None): + model_path = os.path.join(test_data_folder, model_filename) + interpreter = tflite.Interpreter(model_path=model_path, + experimental_delegates=delegates) + interpreter.allocate_tensors() + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Set inputs to tensors. + for i in range(len(inputs)): + interpreter.set_tensor(input_details[i]['index'], inputs[i]) + + interpreter.invoke() + + results = [] + for output in output_details: + results.append(interpreter.get_tensor(output['index'])) + + return results + +def compare_outputs(outputs, expected_outputs): + assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs' + for i in range(len(expected_outputs)): + assert outputs[i].shape == expected_outputs[i].shape, 'Incorrect output shape on output#{}'.format(i) + assert outputs[i].dtype == expected_outputs[i].dtype, 'Incorrect output data type on output#{}'.format(i) + assert outputs[i].all() == expected_outputs[i].all(), 'Incorrect output value on output#{}'.format(i)
\ No newline at end of file |