aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xverif/frameworks/tosa_verif_framework_compiler_runner.py146
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py12
-rw-r--r--verif/frameworks/write_test_json.py6
-rw-r--r--verif/runner/run_command.py28
4 files changed, 160 insertions, 32 deletions
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index bf035cc..ab3db90 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
@@ -276,6 +276,97 @@ def write_reference_runner_json(
json.dump(test_desc, f, indent=" ")
+""" For dynamic shape model, apply 2 steps to perform compilation, shape inference,
+ and serialization."""
+
+
+def compile_dynamic_model(
+ args,
+ framework,
+ test_path,
+ test_name,
+ pre_opt_filename,
+ post_opt_filename,
+ tosa_mlir_filename,
+ compiler_cmd,
+ flatbuffer_dir_fullpath,
+ shape,
+):
+ try:
+ # 1. Compile the dynamic shape model with unknown shapes and tosa shape ops.
+ dyn_tosa_mlir_filename = str(test_path / f"output_{framework}.dyn.tosa.mlir")
+ compile_dynamic_cmd = compiler_cmd.copy()
+ compile_dynamic_cmd.extend(
+ [
+ "--verify-each",
+ post_opt_filename,
+ "-o",
+ dyn_tosa_mlir_filename,
+ ]
+ )
+ compiler_stdout, compiler_stderr = run_sh_command(
+ compile_dynamic_cmd, args.verbose, True
+ )
+
+ compiler_rc_1 = parse_compiler_output(compiler_stdout, compiler_stderr)
+
+ if compiler_rc_1 == TestResult.NOT_LOWERED:
+ print_color(
+ LogColors.RED,
+ f"Results NOT_LOWERED {test_name}, framework {framework}",
+ )
+ return (TestResult.NOT_LOWERED, 0.0, "", test_name)
+
+ def convert_shape_tuple_to_string(tup):
+ string = ""
+ for dim in tup:
+ string = string + str(dim) + ","
+ # skip the last `,` character.
+ return string[0:-1]
+
+ # 2. Resolve unknown shapes, and perform serialization.
+ if not isinstance(shape, tuple):
+ raise Exception("Only single input is supported currently")
+
+ arg0_argument = '"arg0=' + convert_shape_tuple_to_string(shape) + '"'
+
+ compile_and_shape_infer_cmd = compiler_cmd.copy()
+ compile_and_shape_infer_cmd.extend(
+ [
+ f"--tosa-input-shape={arg0_argument}",
+ "--tosa-infer-shapes",
+ dyn_tosa_mlir_filename,
+ "-o",
+ tosa_mlir_filename,
+ "--tosa-serialize",
+ f"--tosa-flatbuffer-filename={flatbuffer_dir_fullpath / f'{test_name}.tosa'}",
+ ]
+ )
+
+ # Convert list type to string type as double quote \" in list structure causes
+ # single quote \' residue in the final command.
+ compiler_stdout, compiler_stderr = run_sh_command(
+ " ".join(map(str, compile_and_shape_infer_cmd)), args.verbose, True
+ )
+
+ compiler_rc_2 = parse_compiler_output(compiler_stdout, compiler_stderr)
+
+ if compiler_rc_2 == TestResult.NOT_LOWERED:
+ print_color(
+ LogColors.RED,
+ f"Results NOT_LOWERED {test_name}, framework {framework}",
+ )
+ return (TestResult.NOT_LOWERED, 0.0, "", test_name)
+
+ except Exception as e:
+ if "same scale constraint" in str(e):
+ print_color(LogColors.RED, f"Results INVALID_MLIR {test_name}: {e}")
+ return (TestResult.INVALID_MLIR, 0.0, e, test_name)
+ else:
+ print_color(LogColors.RED, f"Results COMPILER_ERROR {test_name}: {e}")
+ return (TestResult.COMPILER_ERROR, 0.0, e, test_name)
+
+
def run_test(args, test_path, framework):
msg = ""
@@ -431,7 +522,8 @@ def run_test(args, test_path, framework):
flatbuffer_dir_fullpath.mkdir(exist_ok=True)
- compiler_cmd.extend(
+ compile_and_serialize_cmd = compiler_cmd.copy()
+ compile_and_serialize_cmd.extend(
[
"--verify-each",
post_opt_filename,
@@ -452,27 +544,43 @@ def run_test(args, test_path, framework):
print_color(LogColors.RED, f"Results INVALID_MLIR {test_name}: {e}")
return (TestResult.INVALID_MLIR, 0.0, e, test_name)
- try:
- compiler_stdout, compiler_stderr = run_sh_command(
- compiler_cmd, args.verbose, True
+ if "ifm_dynamic" in test_desc and test_desc["ifm_dynamic"] == 1:
+ compile_dynamic_model(
+ args,
+ framework,
+ test_path,
+ test_name,
+ pre_opt_filename,
+ post_opt_filename,
+ tosa_mlir_filename,
+ compiler_cmd,
+ flatbuffer_dir_fullpath,
+ ifm_np.shape,
)
- compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
- if compiler_rc == TestResult.NOT_LOWERED:
- print_color(
- LogColors.RED,
- f"Results NOT_LOWERED {test_name}, framework {framework}",
+ else:
+ try:
+ compiler_stdout, compiler_stderr = run_sh_command(
+ compile_and_serialize_cmd, args.verbose, True
)
- return (TestResult.NOT_LOWERED, 0.0, "", test_name)
+ compiler_rc = parse_compiler_output(compiler_stdout, compiler_stderr)
+ if compiler_rc == TestResult.NOT_LOWERED:
+ print_color(
+ LogColors.RED,
+ f"Results NOT_LOWERED {test_name}, framework {framework}",
+ )
+ return (TestResult.NOT_LOWERED, 0.0, "", test_name)
- pass
+ pass
- except Exception as e:
- if "same scale constraint" in str(e):
- print_color(LogColors.RED, f"Results INVALID_MLIR {test_name}: {e}")
- return (TestResult.INVALID_MLIR, 0.0, e, test_name)
- else:
- print_color(LogColors.RED, f"Results COMPILER_ERROR {test_name}: {e}")
- return (TestResult.COMPILER_ERROR, 0.0, e, test_name)
+ except Exception as e:
+ if "same scale constraint" in str(e):
+ print_color(LogColors.RED, f"Results INVALID_MLIR {test_name}: {e}")
+ return (TestResult.INVALID_MLIR, 0.0, e, test_name)
+ else:
+ print_color(
+ LogColors.RED, f"Results COMPILER_ERROR {test_name}: {e}"
+ )
+ return (TestResult.COMPILER_ERROR, 0.0, e, test_name)
if framework == "tf":
try:
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 3a9c0ca..538f314 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -826,9 +826,9 @@ TF_OP_LIST = {
},
# number of operands of tuples which spcifies which dim to set to None
# In this case, we have 1 input. So we have 1 tuple
- # We're setting the first input's third dim to None
+ # We're setting the first input's first (batch) dim to None
"dynamic_shape_dim": [
- (3,),
+ (0,),
],
},
"depth_to_space": {
@@ -849,9 +849,9 @@ TF_OP_LIST = {
},
# number of operands of tuples which spcifies which dim to set to None
# In this case, we have 1 input. So we have 1 tuple
- # We're setting the first input's third dim to None
+ # We're setting the first input's first (batch) dim to None
"dynamic_shape_dim": [
- (3,),
+ (0,),
],
},
"one_hot": {
@@ -1166,6 +1166,7 @@ def run_unit_test(
placeholder_signatures = ()
placeholder_npy_filenames = []
placeholder_shapes = []
+ placeholder_dynamic = False
for idx, (name, val) in enumerate(placeholders):
input_shape = tuple(val.shape)
@@ -1176,6 +1177,8 @@ def run_unit_test(
dim = dim_tuple[0]
input_shape = list(input_shape)
input_shape[dim] = None
+ # When any dimension size is unknown, mark the placeholder as dynamic type.
+ placeholder_dynamic = True
addl_args.append(tuple(input_shape))
except KeyError:
@@ -1432,6 +1435,7 @@ def run_unit_test(
ifm_name=placeholder_names,
ifm_file=placeholder_npy_filenames,
ifm_shape=placeholder_shapes,
+ ifm_dynamic=placeholder_dynamic,
framework_exclusions=excluded_framework_list,
quantized=is_quantized,
test_name=test_name,
diff --git a/verif/frameworks/write_test_json.py b/verif/frameworks/write_test_json.py
index d52a777..4e3aa40 100644
--- a/verif/frameworks/write_test_json.py
+++ b/verif/frameworks/write_test_json.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
@@ -16,6 +16,7 @@ def write_test_json(
ifm_name=None,
ifm_file=None,
ifm_shape=None,
+ ifm_dynamic=False,
framework_exclusions=None,
quantized=False,
test_name=None,
@@ -60,6 +61,9 @@ def write_test_json(
ifm_shape = [ifm_shape]
test_desc["ifm_shape"] = ifm_shape
+ if ifm_dynamic:
+ test_desc["ifm_dynamic"] = True
+
# Some tests cannot be used with specific frameworks.
# This list indicates which tests should be excluded from a given framework.
if framework_exclusions:
diff --git a/verif/runner/run_command.py b/verif/runner/run_command.py
index eef5a76..97d9837 100644
--- a/verif/runner/run_command.py
+++ b/verif/runner/run_command.py
@@ -1,5 +1,5 @@
"""Shell command runner function."""
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import shlex
import subprocess
@@ -33,18 +33,28 @@ class RunShCommandError(Exception):
def run_sh_command(full_cmd, verbose=False, capture_output=False):
"""Run an external shell command.
- full_cmd: array containing shell command and its arguments
+ full_cmd: string, or array containing shell command and its arguments
verbose: optional flag that enables verbose output
capture_output: optional flag to return captured stdout/stderr
"""
- # Quote the command line for printing
- full_cmd_esc = [shlex.quote(x) for x in full_cmd]
+
+ is_str = True if isinstance(full_cmd, str) else False
+ is_list = True if isinstance(full_cmd, list) else False
+
+ if is_list:
+ # Quote the command line for printing
+ full_cmd_esc = [shlex.quote(x) for x in full_cmd]
if verbose:
- print("### Running {}".format(" ".join(full_cmd_esc)))
+ if is_list:
+ print("### Running {}".format(" ".join(full_cmd_esc)))
+ if is_str:
+ print("### Running {}".format(full_cmd))
if capture_output:
- rc = subprocess.run(full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ rc = subprocess.run(
+ full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=is_str
+ )
stdout = rc.stdout.decode("utf-8")
stderr = rc.stderr.decode("utf-8")
if verbose:
@@ -54,8 +64,10 @@ def run_sh_command(full_cmd, verbose=False, capture_output=False):
print(stderr, end="")
else:
stdout, stderr = None, None
- rc = subprocess.run(full_cmd)
+ rc = subprocess.run(full_cmd, shell=is_str)
if rc.returncode != 0:
- raise RunShCommandError(rc.returncode, full_cmd_esc, stderr, stdout)
+ raise RunShCommandError(
+ rc.returncode, full_cmd_esc if is_list else full_cmd, stderr, stdout
+ )
return (stdout, stderr)