diff options
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 105 |
1 files changed, 73 insertions, 32 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index a3c6b05..efc819c 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -32,6 +32,7 @@ import math import itertools from enum import IntEnum, Enum, unique +from tosa_ref_run import TosaReturnCode # Include the ../thirdparty/serialization_lib/python directory in PYTHONPATH parent_dir = os.path.dirname(os.path.realpath(__file__)) @@ -65,8 +66,9 @@ class TosaQuantGen: @staticmethod def qgUnary(testGen, op, dtype): qinfo = ts.TosaSerializerQuantInfo() - qinfo.UnaryQuantInfo(TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype)) + qinfo.UnaryQuantInfo( + TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype) + ) return qinfo @staticmethod @@ -86,8 +88,9 @@ class TosaQuantGen: @staticmethod def qgMatmul(testGen, op, dtype): qinfo = ts.TosaSerializerQuantInfo() - qinfo.MatMulQuantInfo(TosaQuantGen.getQinfo(testGen, dtype), - TosaQuantGen.getQinfo(testGen, dtype)) + qinfo.MatMulQuantInfo( + TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype) + ) return qinfo @staticmethod @@ -304,13 +307,11 @@ class TosaTensorGen: assert rank == 2 input_shape = testGen.makeShape(rank) - filter_oc = ( - testGen.rng.integers( - low=testGen.args.tensor_shape_range[0], - high=testGen.args.tensor_shape_range[1], - size=1, - )[0] - ) + filter_oc = testGen.rng.integers( + low=testGen.args.tensor_shape_range[0], + high=testGen.args.tensor_shape_range[1], + size=1, + )[0] filter_shape = np.asarray([filter_oc, input_shape[1]]) bias_shape = np.asarray([filter_oc]) @@ -734,7 +735,10 @@ class TosaArgGen: random_permutations = testGen.rng.permutation(permutations) # Create list of required amount of permutations - arg_list = [("perm{}".format(p), [random_permutations[p].tolist()]) for p in range(limit)] + arg_list = [ + ("perm{}".format(p), [random_permutations[p].tolist()]) + for p in range(limit) + ] return arg_list @staticmethod @@ -1154,7 +1158,7 @@ class TosaTestGen: def build_table(self, op, a): # Constant size depending on type, random values if a.dtype == DType.INT16: - table_dtype = DType.INT16 + table_dtype = DType.INT16 table_arr = self.getRandTensor([513], table_dtype) else: assert a.dtype == DType.INT8 @@ -1497,7 +1501,7 @@ class TosaTestGen: if val.dtype == DType.INT8: input_zp = self.randInt(-128, 128) in_type_width = in_type_width + 1 - elif val.dtype == DType.UINT8: + elif val.dtype == DType.UINT8: input_zp = self.randInt(0, 256) in_type_width = in_type_width + 1 else: @@ -1536,7 +1540,9 @@ class TosaTestGen: scale_arr[i], scale32 ) if shift_arr[i] < 2 or shift_arr[i] > 62: - self.ser.setExpectedFailure(True, "OpRescale: invalid shift value") + self.ser.setExpectedReturnCode( + TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value" + ) # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) @@ -1710,14 +1716,21 @@ class TosaTestGen: # Filter out the rank? if rankFilter is not None and r not in rankFilter: continue - if rankFilter is None and shapeFilter[0] is None and r not in default_test_rank_range: + if ( + rankFilter is None + and shapeFilter[0] is None + and r not in default_test_rank_range + ): continue for t in op["types"]: # Filter tests based on dtype? if dtypeFilter is not None: - if not (t in dtypeFilter or (isinstance(t, list) and t[0] in dtypeFilter)): + if not ( + t in dtypeFilter + or (isinstance(t, list) and t[0] in dtypeFilter) + ): continue # Create the placeholder and const tensors @@ -2660,7 +2673,9 @@ class OutputShaper: # Invalid test parameters? h = 0 w = 0 - ser.setExpectedFailure(True, "Invalid combination of conv2d parameters") + ser.setExpectedReturnCode( + TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters" + ) ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] @@ -2700,7 +2715,9 @@ class OutputShaper: # Invalid test parameters? h = 0 w = 0 - ser.setExpectedFailure(True, "Invalid combination of conv2d parameters") + ser.setExpectedReturnCode( + TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters" + ) ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]] @@ -2725,7 +2742,9 @@ class OutputShaper: # Invalid test parameters? h = 0 w = 0 - ser.setExpectedFailure(True, "Invalid combination of pooling parameters") + ser.setExpectedReturnCode( + TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters" + ) ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]] return ser.addOutput(ofm_shape, ifm.dtype) @@ -2889,39 +2908,59 @@ class OutputShaper: if input_dtype == DType.FLOAT: if stride_fp[0] <= 0 or stride_fp[1] <= 0: - ser.setExpectedFailure(True, "Negative or zero stride") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Negative or zero stride" + ) else: if stride[0] <= 0 or stride[1] <= 0: - ser.setExpectedFailure(True, "Negative or zero stride") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Negative or zero stride" + ) if mode == ResizeMode.BILINEAR: if input_dtype == DType.INT8: if output_dtype != DType.INT32: - ser.setExpectedFailure(True, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) elif input_dtype == DType.INT16: if output_dtype != DType.INT48: - ser.setExpectedFailure(true, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) elif input_dtype == DType.FLOAT: if output_dtype != DType.FLOAT: - ser.setExpectedFailure(true, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) else: - ser.setExpectedFailure(true, "Invalid input data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid input data type" + ) elif mode == ResizeMode.NEAREST: if input_dtype == DType.INT8: if output_dtype != DType.INT8: - ser.setExpectedFailure(True, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) elif input_dtype == DType.INT16: if output_dtype != DType.INT16: - ser.setExpectedFailure(true, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) elif input_dtype == DType.FLOAT: if output_dtype != DType.FLOAT: - ser.setExpectedFailure(true, "Invalid output data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid output data type" + ) else: - ser.setExpectedFailure(true, "Invalid input data type") + ser.setExpectedReturnCode( + TosaReturnCode.ERROR, "Invalid input data type" + ) else: - ser.setExpectedFailure(true, "Invalid resize mode") + ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode") return ser.addOutput(output_dims, output_dtype) @@ -2941,6 +2980,8 @@ class OutputShaper: raise Exception("Unsupported input dtype: {}".format(ifm.dtype)) if output_shape[1] <= 0 or output_shape[2] <= 0: - ser.setExpectedFailure(True, "Negative output shape") + ser.setExpectedReturnCode( + TosaReturnCode.UNPREDICTABLE, "Negative output shape" + ) return ser.addOutput(output_shape, out_dtype) |