aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py101
1 files changed, 57 insertions, 44 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index b5ddbd7..777c059 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -982,6 +982,11 @@ class TosaTestGen:
with open(os.path.join(self.basePath, self.testPath, "desc.json"), "w") as fd:
fd.write(self.ser.writeJson("{}.tosa".format(testName)))
+ def resetRNG(self, seed=None):
+ if seed == None:
+ seed = self.random_seed + 1
+ self.rng = np.random.default_rng(seed)
+
def getRandTensor(self, shape, dtype):
if dtype == DType.BOOL:
np_dt = np.bool
@@ -1694,7 +1699,7 @@ class TosaTestGen:
return acc_out
def genOpTestList(
- self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None
+ self, opName, shapeFilter=[None], rankFilter=None, dtypeFilter=None, testType='positive'
):
try:
@@ -1720,58 +1725,66 @@ class TosaTestGen:
if not shapeFilter:
shapeFilter = [None]
- for r in range(rmin, rmax + 1):
-
- # Filter out the rank?
- if rankFilter is not None and r not in rankFilter:
- continue
- if (
- rankFilter is None
- and shapeFilter[0] is None
- and r not in default_test_rank_range
- ):
- continue
+ # Positive test loop
+ if testType in ['positive', 'both']:
+ for r in range(rmin, rmax + 1):
- for t in op["types"]:
+ # Filter out the rank?
+ if rankFilter is not None and r not in rankFilter:
+ continue
+ if (
+ rankFilter is None
+ and shapeFilter[0] is None
+ and r not in default_test_rank_range
+ ):
+ continue
- # Filter tests based on dtype?
- if dtypeFilter is not None:
- if not (
- t in dtypeFilter
- or (isinstance(t, list) and t[0] in dtypeFilter)
- ):
- continue
+ for t in op["types"]:
- # Create the placeholder and const tensors
- for shape in shapeFilter:
- # A None shape chooses a random shape of a given rank
+ # Filter tests based on dtype?
+ if dtypeFilter is not None:
+ if not (
+ t in dtypeFilter
+ or (isinstance(t, list) and t[0] in dtypeFilter)
+ ):
+ continue
- # Filter out by rank
- if shape is not None and len(shape) != r:
- continue
+ # Create the placeholder and const tensors
+ for shape in shapeFilter:
+ # A None shape chooses a random shape of a given rank
- self.setTargetShape(shape)
- shapeList = tgen_fcn(self, op, r)
+ # Filter out by rank
+ if shape is not None and len(shape) != r:
+ continue
- shapeStr = self.shapeStr(shapeList[0])
- typeStr = self.typeStr(t)
+ self.setTargetShape(shape)
+ shapeList = tgen_fcn(self, op, r)
- # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
- argList = []
- if agen_fcn:
- argList = agen_fcn(self, opName, shapeList, t)
- else:
- argList = [("", [])]
+ shapeStr = self.shapeStr(shapeList[0])
+ typeStr = self.typeStr(t)
- for argStr, args in argList:
- if argStr:
- testStr = "{}_{}_{}_{}".format(
- opName, shapeStr, typeStr, argStr
- )
+ # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
+ argList = []
+ if agen_fcn:
+ argList = agen_fcn(self, opName, shapeList, t)
else:
- testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
-
- testList.append((opName, testStr, t, shapeList, args))
+ argList = [("", [])]
+
+ for argStr, args in argList:
+ if argStr:
+ testStr = "{}_{}_{}_{}".format(
+ opName, shapeStr, typeStr, argStr
+ )
+ else:
+ testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
+
+ testList.append((opName, testStr, t, shapeList, args))
+
+ # Reset RNG so both positive and negative tests are reproducible
+ self.resetRNG()
+ # Negative test loop
+ if testType in ['negative', 'both']:
+ print("Negative tests unsupported")
return testList