aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-04-04 10:03:52 +0100
committerEric Kunze <eric.kunze@arm.com>2024-04-08 17:09:29 +0000
commit129201df8126a16abb0e4fbf7372354021f8a55d (patch)
tree8f5f7160cb8bf0bd488ff9478c3a78e87b9c3fb4
parent9c0a5075d9e184f6b92762b3bc903e021b700e65 (diff)
downloadreference_model-129201df8126a16abb0e4fbf7372354021f8a55d.tar.gz
Add support for multi args in tosa_verif_build_tests
Now supports shorter "--target-rank 0 1" and the original method of "--target-rank 0 --target-rank 1" Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ia45a168588c6fca4dcd4cbbf526ac49cb0bdf621
-rw-r--r--verif/generator/tosa_verif_build_tests.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 4a15834..472ba4d 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -170,9 +170,10 @@ def parseArgs(argv):
ops_group.add_argument(
"--conv-kernel",
dest="conv_kernels",
- action="append",
+ action="extend",
default=[],
type=lambda x: str_to_list(x),
+ nargs="*",
help="Create convolution tests with a particular kernel shape, e.g., 1,4 or 1,3,1 (only 2D kernel sizes will be used for 2D ops, etc.)",
)
@@ -220,28 +221,31 @@ def parseArgs(argv):
tens_group.add_argument(
"--target-shape",
dest="target_shapes",
- action="append",
+ action="extend",
default=[],
+ # Used for parsing a comma-separated list of integers in a string
type=lambda x: str_to_list(x),
+ nargs="*",
help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)",
)
tens_group.add_argument(
"--target-rank",
dest="target_ranks",
- action="append",
+ action="extend",
default=None,
type=lambda x: auto_int(x),
- help="Create tests with a particular input tensor rank",
+ nargs="*",
+ help="Create tests with a particular input tensor rank (may be repeated)",
)
- # Used for parsing a comma-separated list of integers in a string
tens_group.add_argument(
"--target-dtype",
dest="target_dtypes",
- action="append",
+ action="extend",
default=None,
type=lambda x: dtype_str_to_val(x),
+ nargs="*",
help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
)