aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py105
1 files changed, 73 insertions, 32 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index a3c6b05..efc819c 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -32,6 +32,7 @@ import math
import itertools
from enum import IntEnum, Enum, unique
+from tosa_ref_run import TosaReturnCode
# Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH
parent_dir = os.path.dirname(os.path.realpath(__file__))
@@ -65,8 +66,9 @@ class TosaQuantGen:
@staticmethod
def qgUnary(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
@@ -86,8 +88,9 @@ class TosaQuantGen:
@staticmethod
def qgMatmul(testGen, op, dtype):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype),
- TosaQuantGen.getQinfo(testGen, dtype))
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
@@ -304,13 +307,11 @@ class TosaTensorGen:
assert rank == 2
input_shape = testGen.makeShape(rank)
- filter_oc = (
- testGen.rng.integers(
- low=testGen.args.tensor_shape_range[0],
- high=testGen.args.tensor_shape_range[1],
- size=1,
- )[0]
- )
+ filter_oc = testGen.rng.integers(
+ low=testGen.args.tensor_shape_range[0],
+ high=testGen.args.tensor_shape_range[1],
+ size=1,
+ )[0]
filter_shape = np.asarray([filter_oc, input_shape[1]])
bias_shape = np.asarray([filter_oc])
@@ -734,7 +735,10 @@ class TosaArgGen:
random_permutations = testGen.rng.permutation(permutations)
# Create list of required amount of permutations
- arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)]
+ arg_list = [
+ ("perm{}".format(p), [random_permutations[p].tolist()])
+ for p in range(limit)
+ ]
return arg_list
@staticmethod
@@ -1154,7 +1158,7 @@ class TosaTestGen:
def build_table(self, op, a):
# Constant size depending on type, random values
if a.dtype == DType.INT16:
- table_dtype = DType.INT16
+ table_dtype = DType.INT16
table_arr = self.getRandTensor([513], table_dtype)
else:
assert a.dtype == DType.INT8
@@ -1497,7 +1501,7 @@ class TosaTestGen:
if val.dtype == DType.INT8:
input_zp = self.randInt(-128, 128)
in_type_width = in_type_width + 1
- elif val.dtype == DType.UINT8:
+ elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
in_type_width = in_type_width + 1
else:
@@ -1536,7 +1540,9 @@ class TosaTestGen:
scale_arr[i], scale32
)
if shift_arr[i] < 2 or shift_arr[i] > 62:
- self.ser.setExpectedFailure(True, "OpRescale: invalid shift value")
+ self.ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value"
+ )
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
@@ -1710,14 +1716,21 @@ class TosaTestGen:
# Filter out the rank?
if rankFilter is not None and r not in rankFilter:
continue
- if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range:
+ if (
+ rankFilter is None
+ and shapeFilter[0] is None
+ and r not in default_test_rank_range
+ ):
continue
for t in op["types"]:
# Filter tests based on dtype?
if dtypeFilter is not None:
- if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)):
+ if not (
+ t in dtypeFilter
+ or (isinstance(t, list) and t[0] in dtypeFilter)
+ ):
continue
# Create the placeholder and const tensors
@@ -2660,7 +2673,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
@@ -2700,7 +2715,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of conv2d parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
@@ -2725,7 +2742,9 @@ class OutputShaper:
# Invalid test parameters?
h = 0
w = 0
- ser.setExpectedFailure(True, "Invalid combination of pooling parameters")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters"
+ )
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
return ser.addOutput(ofm_shape, ifm.dtype)
@@ -2889,39 +2908,59 @@ class OutputShaper:
if input_dtype == DType.FLOAT:
if stride_fp[0] <= 0 or stride_fp[1] <= 0:
- ser.setExpectedFailure(True, "Negative or zero stride")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Negative or zero stride"
+ )
else:
if stride[0] <= 0 or stride[1] <= 0:
- ser.setExpectedFailure(True, "Negative or zero stride")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Negative or zero stride"
+ )
if mode == ResizeMode.BILINEAR:
if input_dtype == DType.INT8:
if output_dtype != DType.INT32:
- ser.setExpectedFailure(True, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.INT16:
if output_dtype != DType.INT48:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid input data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid input data type"
+ )
elif mode == ResizeMode.NEAREST:
if input_dtype == DType.INT8:
if output_dtype != DType.INT8:
- ser.setExpectedFailure(True, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.INT16:
if output_dtype != DType.INT16:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setExpectedFailure(true, "Invalid output data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid output data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid input data type")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.ERROR, "Invalid input data type"
+ )
else:
- ser.setExpectedFailure(true, "Invalid resize mode")
+ ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode")
return ser.addOutput(output_dims, output_dtype)
@@ -2941,6 +2980,8 @@ class OutputShaper:
raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
if output_shape[1] <= 0 or output_shape[2] <= 0:
- ser.setExpectedFailure(True, "Negative output shape")
+ ser.setExpectedReturnCode(
+ TosaReturnCode.UNPREDICTABLE, "Negative output shape"
+ )
return ser.addOutput(output_shape, out_dtype)