diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 108 |
1 files changed, 87 insertions, 21 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index f7837a0..32f4341 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -638,9 +638,9 @@ class TosaTensorValuesGen: if ( error_name is not None or not gtu.dtypeIsSupportedByCompliance(dtypeList[0]) - or opName in ("avg_pool2d",) + or "data_gen" not in testGen.TOSA_OP_LIST[opName] ): - # Fall back to original path when dealing with unsupported types + # Fall back to original path when dealing with unsupported types or ops # First turn off lazy data gen so we always produce data lazy_data_gen = testGen.args.lazy_data_gen @@ -660,7 +660,11 @@ class TosaTensorValuesGen: # Create data generator meta-data dg_type = argsDict["dg_type"] - dg_tens_meta = {} + tens_data = { + "version": "0.1", + "tensors": {}, + } + dg_tens_meta = tens_data["tensors"] tens_ser_list = [] for idx, shape in enumerate(shapeList): @@ -669,15 +673,12 @@ class TosaTensorValuesGen: tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"] tens_meta["shape"] = [int(i) for i in shape] tens_meta["input_pos"] = idx - tens_meta["op"] = opName.upper() + tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper() if idx < pCount: tens_meta["input_type"] = "VARIABLE" - tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None) else: tens_meta["input_type"] = "CONSTANT" - tens = testGen.ser.addConst(shape, dtypeList[idx], None) - tens_ser_list.append(tens) if dg_type == gtu.DataGenType.PSEUDO_RANDOM: info = {} @@ -691,23 +692,55 @@ class TosaTensorValuesGen: elif dg_type == gtu.DataGenType.DOT_PRODUCT: info = {} info["s"] = argsDict["s"] - info["ks"] = argsDict["ks"] - for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO: - if key in argsDict: - if key.endswith("_type"): - info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"] - else: - info[key] = argsDict[key] + info["ks"] = int(argsDict["ks"]) + if "acc_type" in argsDict: + # Convert type number into JSON name + info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][ + "json" + ] + if "kernel" in argsDict: + info["kernel"] = [int(k) for k in argsDict["kernel"]] + if "axis" in argsDict: + info["axis"] = int(argsDict["axis"]) tens_meta["dot_product_info"] = info else: # TODO - other data gen type assert False, "TODO: support other data gen types" + + # Using the finished generate config meta data - generate the data if + # needed and assign a tensor name from the serializer + + # Need to generate data when not lazy or for the bias tensor as we need + # to work out if the bias data is non-zero for compliance + if not testGen.args.lazy_data_gen or ( + idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT + ): + # Give this tensor a temporary name until we get one from the serializer + temp_name = f"placeholder_{idx}" + dg_tens_meta[temp_name] = tens_meta + # Create data now using the temporary name to access meta details + data = testGen.dgl.get_tensor_data(temp_name, tens_data) + # Remove the item as we will give it the correct name later + del dg_tens_meta[temp_name] + + if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT: + # The KS value used by compliance verification is altered when the + # bias data is non-zero + if max(abs(data)) > 0.0: + argsDict["ksb"] = argsDict["ks"] + 1 + + if testGen.args.lazy_data_gen: + data = None + + if tens_meta["input_type"] == "VARIABLE": + tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data) + else: + tens = testGen.ser.addConst(shape, dtypeList[idx], data) + + tens_ser_list.append(tens) + # Add the meta data to the list using the serializer tensor name dg_tens_meta[tens.name] = tens_meta - tens_data = { - "version": "0.1", - "tensors": dg_tens_meta, - } return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data) @staticmethod @@ -1206,8 +1239,11 @@ class TosaArgGen: accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) - # Check the rank + # Op type checks conv3d = opName.startswith("conv3d") + depthwise = opName.startswith("depthwise") + + # Check the rank rank = 5 if conv3d else 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == rank @@ -1215,8 +1251,12 @@ class TosaArgGen: # kernel rank omits channels k_rank = rank - 2 - k_pos = 0 if opName.startswith("depthwise") else 1 + k_pos = 0 if depthwise else 1 k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)]) + # compliance size - KS + k_size = gtu.product(k_shape) + if not depthwise: + k_size *= ifm_shape[-1] if not testGen.args.level8k: # Generate comprehensive argument lists @@ -1363,6 +1403,24 @@ class TosaArgGen: # Test will consume too much memory - skip it continue + # Compliance - number of dot product calculations + if depthwise: + # TODO - add support + dots = 0 + else: + dots = gtu.product( + (ifm_shape[0], *outputs, filter_shape[0]) + ) + args_dict = { + "acc_type": accum_dtype, + "stride": s, + "pad": p, + "dilation": d, + "kernel": k_shape, + "ks": k_size, + "dot_products": dots, + } + # Support for larger values than 9 needs different delimiter delim = "" if max(s + p + d) <= 9 else "x" arg_list.append( @@ -1373,11 +1431,19 @@ class TosaArgGen: delim.join([str(x) for x in p]), delim.join([str(x) for x in d]), ), - [accum_dtype, s, p, d], + args_dict, ) ) n += 1 + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtypes[0], + arg_list, + error_name, + ) + # Return list of tuples: (arg_str, args_dict) return arg_list @staticmethod |