aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-07-16 15:38:20 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-09-08 12:17:59 +0100
commit74567097e161ce9cdb0f09d474da6d9d36aa7476 (patch)
treea210dbedfdc5f1fa588a31c49e2b7f63f4624daa
parent4b2881a7c1cbb2a4b0b24cafcdef28af0f4975c1 (diff)
downloadreference_model-74567097e161ce9cdb0f09d474da6d9d36aa7476.tar.gz
Allow user to specify test type generated
* The option --test-type allows the user to select 'positive', 'negative', or 'both' types of tests produced by the test generator. * Reset RNG when looping through negative test generation (generation not implemented) Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I1bfcb3170e7380be0f98b36b3d4abc4779a05abe
-rw-r--r--verif/tosa_test_gen.py101
-rwxr-xr-xverif/tosa_verif_build_tests.py16
2 files changed, 71 insertions, 46 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index b5ddbd7..777c059 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -982,6 +982,11 @@ class TosaTestGen:
with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
fd.write(self.ser.writeJson("{}.tosa".format(testName)))
+ def resetRNG(self, seed=None):
+ if seed == None:
+ seed = self.random_seed + 1
+ self.rng = np.random.default_rng(seed)
+
def getRandTensor(self, shape, dtype):
if dtype == DType.BOOL:
np_dt = np.bool
@@ -1694,7 +1699,7 @@ class TosaTestGen:
return acc_out
def genOpTestList(
- self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
+ self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
):
try:
@@ -1720,58 +1725,66 @@ class TosaTestGen:
if not shapeFilter:
shapeFilter = [None]
- for r in range(rmin, rmax + 1):
-
- # Filter out the rank?
- if rankFilter is not None and r not in rankFilter:
- continue
- if (
- rankFilter is None
- and shapeFilter[0] is None
- and r not in default_test_rank_range
- ):
- continue
+ # Positive test loop
+ if testType in ['positive', 'both']:
+ for r in range(rmin, rmax + 1):
- for t in op["types"]:
+ # Filter out the rank?
+ if rankFilter is not None and r not in rankFilter:
+ continue
+ if (
+ rankFilter is None
+ and shapeFilter[0] is None
+ and r not in default_test_rank_range
+ ):
+ continue
- # Filter tests based on dtype?
- if dtypeFilter is not None:
- if not (
- t in dtypeFilter
- or (isinstance(t, list) and t[0] in dtypeFilter)
- ):
- continue
+ for t in op["types"]:
- # Create the placeholder and const tensors
- for shape in shapeFilter:
- # A None shape chooses a random shape of a given rank
+ # Filter tests based on dtype?
+ if dtypeFilter is not None:
+ if not (
+ t in dtypeFilter
+ or (isinstance(t, list) and t[0] in dtypeFilter)
+ ):
+ continue
- # Filter out by rank
- if shape is not None and len(shape) != r:
- continue
+ # Create the placeholder and const tensors
+ for shape in shapeFilter:
+ # A None shape chooses a random shape of a given rank
- self.setTargetShape(shape)
- shapeList = tgen_fcn(self, op, r)
+ # Filter out by rank
+ if shape is not None and len(shape) != r:
+ continue
- shapeStr = self.shapeStr(shapeList[0])
- typeStr = self.typeStr(t)
+ self.setTargetShape(shape)
+ shapeList = tgen_fcn(self, op, r)
- # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
- argList = []
- if agen_fcn:
- argList = agen_fcn(self, opName, shapeList, t)
- else:
- argList = [("", [])]
+ shapeStr = self.shapeStr(shapeList[0])
+ typeStr = self.typeStr(t)
- for argStr, args in argList:
- if argStr:
- testStr = "{}_{}_{}_{}".format(
- opName, shapeStr, typeStr, argStr
- )
+ # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
+ argList = []
+ if agen_fcn:
+ argList = agen_fcn(self, opName, shapeList, t)
else:
- testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
-
- testList.append((opName, testStr, t, shapeList, args))
+ argList = [("", [])]
+
+ for argStr, args in argList:
+ if argStr:
+ testStr = "{}_{}_{}_{}".format(
+ opName, shapeStr, typeStr, argStr
+ )
+ else:
+ testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
+
+ testList.append((opName, testStr, t, shapeList, args))
+
+ # Reset RNG so both positive and negative tests are reproducible
+ self.resetRNG()
+ # Negative test loop
+ if testType in ['negative', 'both']:
+ print("Negative tests unsupported")
return testList
diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py
index 343d8d4..02d1934 100755
--- a/verif/tosa_verif_build_tests.py
+++ b/verif/tosa_verif_build_tests.py
@@ -201,6 +201,14 @@ def parseArgs():
help="Allow constant input tensors for concat operator",
)
+ parser.add_argument(
+ "--test-type",
+ dest="test_type",
+ choices=['positive', 'negative', 'both'],
+ default="positive",
+ type=str,
+ help="type of tests produced, postive, negative, or both",
+ )
args = parser.parse_args()
return args
@@ -221,15 +229,19 @@ def main():
shapeFilter=args.target_shapes,
rankFilter=args.target_ranks,
dtypeFilter=args.target_dtypes,
+ testType=args.test_type
)
)
print("{} matching tests".format(len(testList)))
+ results = []
for opName, testStr, dtype, shapeList, testArgs in testList:
if args.verbose:
print(testStr)
- ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs)
- print("Done creating {} tests".format(len(testList)))
+ results.append(ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs))
+
+ print(f"Done creating {len(results)} tests")
+
if __name__ == "__main__":