aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_verif_build_tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_verif_build_tests.py')
-rw-r--r--verif/generator/tosa_verif_build_tests.py38
1 files changed, 35 insertions, 3 deletions
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index ab78b1a..bc1ec8e 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -2,20 +2,24 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import re
+import sys
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames
+OPTION_FP_VALUES_RANGE = "--fp-values-range"
+
# Used for parsing a comma-separated list of integers in a string
# to an actual list of integers
-def str_to_list(in_s):
+def str_to_list(in_s, is_float=False):
"""Converts a comma-separated list of string integers to a python list of ints"""
lst = in_s.split(",")
out_list = []
for i in lst:
- out_list.append(int(i))
+ val = float(i) if is_float else int(i)
+ out_list.append(val)
return out_list
@@ -25,6 +29,26 @@ def auto_int(x):
def parseArgs(argv):
+ """Parse the command line arguments."""
+ if argv is None:
+ argv = sys.argv[1:]
+
+ if OPTION_FP_VALUES_RANGE in argv:
+ # Argparse fix for hyphen (minus values) in argument values
+ # convert "ARG VAL" into "ARG=VAL"
+ # Example --fp-values-range -2.0,2.0 -> --fp-values-range=-2.0,2.0
+ new_argv = []
+ idx = 0
+ while idx < len(argv):
+ arg = argv[idx]
+ if arg == OPTION_FP_VALUES_RANGE and idx + 1 < len(argv):
+ val = argv[idx + 1]
+ if val.startswith("-"):
+ arg = f"{arg}={val}"
+ idx += 1
+ new_argv.append(arg)
+ idx += 1
+ argv = new_argv
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -61,6 +85,14 @@ def parseArgs(argv):
)
parser.add_argument(
+ OPTION_FP_VALUES_RANGE,
+ dest="tensor_fp_value_range",
+ default="0.0,1.0",
+ type=lambda x: str_to_list(x, is_float=True),
+ help="Min,Max range of floating point tensor values",
+ )
+
+ parser.add_argument(
"--max-batch-size",
dest="max_batch_size",
default=1,
@@ -132,7 +164,7 @@ def parseArgs(argv):
help="Upper limit on width and height output dimensions for `resize` op. Default: 1000",
)
- # Targetting a specific shape/rank/dtype
+ # Targeting a specific shape/rank/dtype
parser.add_argument(
"--target-shape",
dest="target_shapes",