aboutsummaryrefslogtreecommitdiff
path: root/delegate/python/test/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/python/test/utils.py')
-rw-r--r--delegate/python/test/utils.py31
1 files changed, 30 insertions, 1 deletions
diff --git a/delegate/python/test/utils.py b/delegate/python/test/utils.py
index 3adc24fe35..f3761ec8a1 100644
--- a/delegate/python/test/utils.py
+++ b/delegate/python/test/utils.py
@@ -21,4 +21,33 @@ def run_mock_model(delegate, test_data_folder):
input_data = np.array(np.random.random_sample(input_shape), dtype=np.uint8)
interpreter.set_tensor(input_details[0]['index'], input_data)
- interpreter.invoke() \ No newline at end of file
+ interpreter.invoke()
+
+def run_inference(test_data_folder, model_filename, inputs, delegates=None):
+ model_path = os.path.join(test_data_folder, model_filename)
+ interpreter = tflite.Interpreter(model_path=model_path,
+ experimental_delegates=delegates)
+ interpreter.allocate_tensors()
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Set inputs to tensors.
+ for i in range(len(inputs)):
+ interpreter.set_tensor(input_details[i]['index'], inputs[i])
+
+ interpreter.invoke()
+
+ results = []
+ for output in output_details:
+ results.append(interpreter.get_tensor(output['index']))
+
+ return results
+
+def compare_outputs(outputs, expected_outputs):
+ assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
+ for i in range(len(expected_outputs)):
+ assert outputs[i].shape == expected_outputs[i].shape, 'Incorrect output shape on output#{}'.format(i)
+ assert outputs[i].dtype == expected_outputs[i].dtype, 'Incorrect output data type on output#{}'.format(i)
+ assert outputs[i].all() == expected_outputs[i].all(), 'Incorrect output value on output#{}'.format(i) \ No newline at end of file