From 129201df8126a16abb0e4fbf7372354021f8a55d Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 4 Apr 2024 10:03:52 +0100 Subject: Add support for multi args in tosa_verif_build_tests Now supports shorter "--target-rank 0 1" and the original method of "--target-rank 0 --target-rank 1" Signed-off-by: Jeremy Johnson Change-Id: Ia45a168588c6fca4dcd4cbbf526ac49cb0bdf621 --- verif/generator/tosa_verif_build_tests.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py index 4a15834..472ba4d 100644 --- a/verif/generator/tosa_verif_build_tests.py +++ b/verif/generator/tosa_verif_build_tests.py @@ -170,9 +170,10 @@ def parseArgs(argv): ops_group.add_argument( "--conv-kernel", dest="conv_kernels", - action="append", + action="extend", default=[], type=lambda x: str_to_list(x), + nargs="*", help="Create convolution tests with a particular kernel shape, e.g., 1,4 or 1,3,1 (only 2D kernel sizes will be used for 2D ops, etc.)", ) @@ -220,28 +221,31 @@ def parseArgs(argv): tens_group.add_argument( "--target-shape", dest="target_shapes", - action="append", + action="extend", default=[], + # Used for parsing a comma-separated list of integers in a string type=lambda x: str_to_list(x), + nargs="*", help="Create tests with a particular input tensor shape, e.g., 1,4,4,8 (may be repeated for tests that require multiple input shapes)", ) tens_group.add_argument( "--target-rank", dest="target_ranks", - action="append", + action="extend", default=None, type=lambda x: auto_int(x), - help="Create tests with a particular input tensor rank", + nargs="*", + help="Create tests with a particular input tensor rank (may be repeated)", ) - # Used for parsing a comma-separated list of integers in a string tens_group.add_argument( "--target-dtype", dest="target_dtypes", - action="append", + action="extend", default=None, type=lambda x: dtype_str_to_val(x), + nargs="*", help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)", ) -- cgit v1.2.1