aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2024-03-26 20:51:48 +0000
committerJerry Ge <jerry.ge@arm.com>2024-03-26 20:55:51 +0000
commitd5b1512b1d2cea3b87e52a0ecc123db2a7a7cad3 (patch)
tree8dd57b438402083cded82347a07e85b784ce2f92
parent42e183cae08b301083416481e7bac92f04f0ce21 (diff)
downloadreference_model-d5b1512b1d2cea3b87e52a0ecc123db2a7a7cad3.tar.gz
Add variable tensor fields for test descriptors
- Add variable_name and variable_file to the desc.json file for writing variable tensors to numpy - Add the key of num_variables in the unit test declaration to specify the number of variable tensors in the graph Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I1109f66ffed52e49dbb14f4a8aca64baa2bea622
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py19
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py8
-rw-r--r--verif/frameworks/write_test_json.py4
3 files changed, 31 insertions, 0 deletions
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 56daa51..82e3aad 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -259,6 +259,8 @@ def write_reference_runner_json(
ifm_file,
ofm_name,
ofm_file,
+ variable_name,
+ variable_file,
expected_failure=False,
):
"""Write a json test file so that it is fairly easy to pick up the test
@@ -270,6 +272,8 @@ def write_reference_runner_json(
test_desc["ifm_file"] = ifm_file
test_desc["ofm_name"] = ofm_name
test_desc["ofm_file"] = ofm_file
+ test_desc["variable_name"] = variable_name
+ test_desc["variable_file"] = variable_file
test_desc["expected_failure"] = expected_failure
with open(filename, "w") as f:
@@ -630,6 +634,19 @@ def run_test(args, test_path, framework):
else:
reference_runner_ofm_name = ["TosaOutput_0"]
+ if "num_variables" in test_desc:
+ num_variable = test_desc["num_variables"]
+ else:
+ num_variable = 0
+ reference_runner_variable_name = []
+ reference_runner_variable_file = []
+
+ for i in range(num_variable):
+ variable_name_str = "Variable_" + str(i)
+ variable_file_str = "variable_output_" + str(i) + ".npy"
+ reference_runner_variable_name.append(variable_name_str)
+ reference_runner_variable_file.append(variable_file_str)
+
write_reference_runner_json(
filename=str(test_path / flatbuffer_dir / "desc.json"),
tosa_filename=f"{test_name}.tosa",
@@ -637,6 +654,8 @@ def run_test(args, test_path, framework):
ifm_file=reference_runner_ifm_file,
ofm_name=reference_runner_ofm_name,
ofm_file=["ref_model_output_0.npy"],
+ variable_name=reference_runner_variable_name,
+ variable_file=reference_runner_variable_file,
)
ref_model_cmd = [
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 2a7d484..8ae0286 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -940,6 +940,7 @@ TF_OP_LIST = {
tf.float32,
]
},
+ "num_variables": 2,
},
"gru": {
"operands": (1, 0),
@@ -1445,6 +1446,12 @@ def run_unit_test(
_, test_name = os.path.split(test_dir)
+ # For specifying the number of variable tensors if the graph has any
+ try:
+ num_varaibles = op["num_variables"]
+ except KeyError:
+ num_varaibles = 0
+
# Write out test descriptor
write_test_json(
filename=os.path.join(test_dir, "test.json"),
@@ -1461,6 +1468,7 @@ def run_unit_test(
framework_exclusions=excluded_framework_list,
quantized=is_quantized,
test_name=test_name,
+ num_variables=num_varaibles,
)
except Exception as e:
msg = "Error running task: {}".format(e)
diff --git a/verif/frameworks/write_test_json.py b/verif/frameworks/write_test_json.py
index 4e3aa40..cd42198 100644
--- a/verif/frameworks/write_test_json.py
+++ b/verif/frameworks/write_test_json.py
@@ -20,6 +20,7 @@ def write_test_json(
framework_exclusions=None,
quantized=False,
test_name=None,
+ num_variables=None,
):
test_desc = dict()
@@ -74,5 +75,8 @@ def write_test_json(
if quantized:
test_desc["quantized"] = 1
+ if num_variables:
+ test_desc["num_variables"] = num_variables
+
with open(filename, "w") as f:
json.dump(test_desc, f, indent=" ")