diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 79 |
1 files changed, 60 insertions, 19 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 00490fa..a655a50 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -731,7 +731,11 @@ class TosaTensorValuesGen: # Change from inclusive to exclusive range data_range = (data_range[0], data_range[1] + 1) # Ignore lazy data gen option and create data array using any range limits - arr = testGen.getRandTensor(shape, dtype, data_range) + + if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None: + arr = np.int64(argsDict["fixed_data"][idx]) + else: + arr = testGen.getRandTensor(shape, dtype, data_range) if roundMode: arr = np.round(arr) if idx < pCount: @@ -751,7 +755,13 @@ class TosaTensorValuesGen: for idx, shape in enumerate(shapeList): tens_meta = {} - tens_meta["generator"] = gtu.DataGenType(dg_type).name + if "fixed_data" in argsDict and argsDict["fixed_data"][idx] is not None: + tens_meta["generator"] = gtu.DataGenType( + gtu.DataGenType.FIXED_DATA + ).name + else: + tens_meta["generator"] = gtu.DataGenType(dg_type).name + tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"] tens_meta["shape"] = [int(i) for i in shape] tens_meta["input_pos"] = idx @@ -764,23 +774,30 @@ class TosaTensorValuesGen: if dg_type == gtu.DataGenType.PSEUDO_RANDOM: info = {} - # TODO - generate seed for this generator based on test - info["rng_seed"] = 42 - - data_range = None - if "data_range_list" in argsDict: - data_range = argsDict["data_range_list"][idx]["range"] - if "round" in argsDict["data_range_list"][idx]: - info["round"] = argsDict["data_range_list"][idx]["round"] - elif "data_range" in argsDict: - data_range = argsDict["data_range"] - - if data_range is None: - data_range = testGen.getDTypeRange( - dtypeList[idx], high_inclusive=True - ) - info["range"] = [str(v) for v in data_range] - tens_meta["pseudo_random_info"] = info + if ( + tens_meta["generator"] + == gtu.DataGenType(gtu.DataGenType.FIXED_DATA).name + ): + info["data"] = [int(i) for i in argsDict["fixed_data"][idx]] + tens_meta["fixed_data_info"] = info + else: + # TODO - generate seed for this generator based on test + info["rng_seed"] = 42 + + data_range = None + if "data_range_list" in argsDict: + data_range = argsDict["data_range_list"][idx]["range"] + if "round" in argsDict["data_range_list"][idx]: + info["round"] = argsDict["data_range_list"][idx]["round"] + elif "data_range" in argsDict: + data_range = argsDict["data_range"] + + if data_range is None: + data_range = testGen.getDTypeRange( + dtypeList[idx], high_inclusive=True + ) + info["range"] = [str(v) for v in data_range] + tens_meta["pseudo_random_info"] = info elif dg_type == gtu.DataGenType.DOT_PRODUCT: info = {} info["s"] = argsDict["s"] @@ -812,6 +829,9 @@ class TosaTensorValuesGen: dg_tens_meta[temp_name] = tens_meta # Create data now using the temporary name to access meta details data = testGen.dgl.get_tensor_data(temp_name, tens_data) + if tens_meta["data_type"] == "SHAPE": + # Tensor type SHAPE and Numpy file type must be the same + data = np.int64(data) # Remove the item as we will give it the correct name later del dg_tens_meta[temp_name] @@ -1014,6 +1034,27 @@ class TosaTensorValuesGen: return placeholders @staticmethod + def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + dtypeList[1] = DType.SHAPE + shapeList[1] = [len(argsDict["new_shape"])] + # Create a new list for the pre-generated data in argsDict["fixed_data"] + argsDict["fixed_data"] = [None, argsDict["new_shape"]] + + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, op, dtypeList, shapeList, argsDict, error_name + ) + + @staticmethod + def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None): + dtypeList[1] = DType.SHAPE + shapeList[1] = [len(argsDict["multiples"])] + argsDict["fixed_data"] = [None, argsDict["multiples"]] + + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, op, dtypeList, shapeList, argsDict, error_name + ) + + @staticmethod def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL |