aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/write_test_json.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/write_test_json.py')
-rw-r--r--verif/frameworks/write_test_json.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/verif/frameworks/write_test_json.py b/verif/frameworks/write_test_json.py
new file mode 100644
index 0000000..68dfc8f
--- /dev/null
+++ b/verif/frameworks/write_test_json.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import json
+
+# Used by basic_test_generator to create test description
+
+
+def write_test_json(
+ filename,
+ tf_model_filename=None,
+ tf_result_npy_filename=None,
+ tf_result_name=None,
+ tflite_model_filename=None,
+ tflite_result_npy_filename=None,
+ tflite_result_name=None,
+ ifm_name=None,
+ ifm_file=None,
+ ifm_shape=None,
+ framework_exclusions=None,
+ quantized=False,
+):
+
+ test_desc = dict()
+
+ if tf_model_filename:
+ test_desc["tf_model_filename"] = tf_model_filename
+
+ if tf_result_npy_filename:
+ test_desc["tf_result_npy_filename"] = tf_result_npy_filename
+
+ if tf_result_name:
+ test_desc["tf_result_name"] = tf_result_name
+
+ if tflite_model_filename:
+ test_desc["tflite_model_filename"] = tflite_model_filename
+
+ if tflite_result_npy_filename:
+ test_desc["tflite_result_npy_filename"] = tflite_result_npy_filename
+
+ if tflite_result_name:
+ test_desc["tflite_result_name"] = tflite_result_name
+
+ if ifm_file:
+ if not isinstance(ifm_file, list):
+ ifm_file = [ifm_file]
+ test_desc["ifm_file"] = ifm_file
+
+ # Make sure these arguments are wrapped as lists
+ if ifm_name:
+ if not isinstance(ifm_name, list):
+ ifm_name = [ifm_name]
+ test_desc["ifm_name"] = ifm_name
+
+ if ifm_shape:
+ if not isinstance(ifm_shape, list):
+ ifm_shape = [ifm_shape]
+ test_desc["ifm_shape"] = ifm_shape
+
+ # Some tests cannot be used with specific frameworks.
+ # This list indicates which tests should be excluded from a given framework.
+ if framework_exclusions:
+ if not isinstance(framework_exclusions, list):
+ framework_exclusions = [framework_exclusions]
+ test_desc["framework_exclusions"] = framework_exclusions
+
+ if quantized:
+ test_desc["quantized"] = 1
+
+ with open(filename, "w") as f:
+ json.dump(test_desc, f, indent=" ")