aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-09-05 11:39:26 +0100
committerEric Kunze <eric.kunze@arm.com>2023-09-07 16:04:17 +0000
commit1271c44bd2c9e670e132db491a053a0e6603798f (patch)
tree98d3af1572ef38137d876ad858231ebd807a936e /verif/generator/tosa_arg_gen.py
parent77fc614916c1afa506fccb0ff2e5260aae8608b6 (diff)
downloadreference_model-1271c44bd2c9e670e132db491a053a0e6603798f.tar.gz
Initial lazy data-gen and compliance test build support
Add initial support for compliance and lazy data-gen meta data added to desc.json for MATMUL. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I00c047814134a96d7c98d890e93b5884e25b8e64
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py174
1 files changed, 160 insertions, 14 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 97ff237..8d96090 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -4,12 +4,10 @@ import itertools
import math
import warnings
+import generator.tosa_utils as gtu
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
-from generator.tosa_utils import get_accum_dtype_from_tgTypes
-from generator.tosa_utils import get_wrong_output_type
-from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
from tosa.Op import Op
@@ -606,11 +604,18 @@ class TosaTensorGen:
class TosaTensorValuesGen:
- """Tensor Value generators create the random data for each test."""
+ """Tensor Value generators create the random data for each tensor in each test."""
def __init__(self):
pass
+ class TVGInfo:
+ """Enhanced tensor values information including data gen dict."""
+
+ def __init__(self, tensorList, dataGenDict):
+ self.tensorList = tensorList
+ self.dataGenDict = dataGenDict
+
@staticmethod
def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
pCount, cCount = op["operands"]
@@ -624,6 +629,87 @@ class TosaTensorValuesGen:
return tens
@staticmethod
+ def tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
+ ):
+ # Variable inputs versus constants
+ pCount, cCount = testGen.TOSA_OP_LIST[opName]["operands"]
+
+ overrideLazy = False
+ if not gtu.dtypeIsFloat(dtypeList[0]) and testGen.args.lazy_data_gen:
+ # TEMPORARY OVERRIDE for integer types
+ overrideLazy = True
+ testGen.args.lazy_data_gen = False
+
+ # TODO - Change to generation of data using library!
+ # For now - we fall back to original path (or when dealing with non-floats)
+ if not testGen.args.lazy_data_gen:
+ tens_ser_list = TosaTensorValuesGen.tvgDefault(
+ testGen,
+ testGen.TOSA_OP_LIST[opName],
+ dtypeList,
+ shapeList,
+ [],
+ error_name,
+ )
+ if overrideLazy:
+ # Return to lazy mode
+ testGen.args.lazy_data_gen = True
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
+
+ # Create data generator meta-data
+ dg_type = argsDict["dg_type"]
+ dg_tens_meta = {}
+ tens_ser_list = []
+ for idx, shape in enumerate(shapeList):
+
+ tens_meta = {}
+ tens_meta["generator"] = gtu.DataGenType(dg_type).name
+ tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
+ tens_meta["shape"] = [int(i) for i in shape]
+ tens_meta["input_pos"] = idx
+ tens_meta["op"] = opName
+
+ if idx < pCount:
+ tens_meta["input_type"] = "variable"
+ tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None)
+ else:
+ tens_meta["input_type"] = "constant"
+ tens = testGen.ser.addConst(shape, dtypeList[idx], None)
+ tens_ser_list.append(tens)
+
+ if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
+ info = {}
+ # TODO - generate seed for this generator based on test
+ info["rng_seed"] = -1
+ info["range"] = [
+ str(v)
+ for v in testGen.getDTypeRange(dtypeList[idx], high_inclusive=True)
+ ]
+ tens_meta["pseudo_random_info"] = info
+ elif dg_type == gtu.DataGenType.DOT_PRODUCT:
+ info = {}
+ info["s"] = argsDict["s"]
+ info["ks"] = argsDict["ks"]
+ for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
+ if key in argsDict:
+ if key.endswith("_type"):
+ info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"]
+ else:
+ info[key] = argsDict[key]
+ tens_meta["dot_product_info"] = info
+ else:
+ # TODO - other data gen type
+ assert False, "TODO: support other data gen types"
+ dg_tens_meta[tens.name] = tens_meta
+
+ tens_data = {
+ "version": "0.1",
+ "tensors": dg_tens_meta,
+ }
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
+
+ @staticmethod
def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
pCount, cCount = op["operands"]
@@ -1024,6 +1110,50 @@ class TosaArgGen:
pass
@staticmethod
+ def _add_data_generators(testGen, opName, dtype, arg_list, error_name, **kwargs):
+ """Add extra tests for each type of data generator for this op."""
+ if error_name is None and "data_gen" in testGen.TOSA_OP_LIST[opName]:
+ if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+ dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
+ else:
+ dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
+ else:
+ # Error test or No data generator types listed - assume random
+ dataGenTypesList = (gtu.DataGenType.PSEUDO_RANDOM,)
+
+ # Expand arg list with other data generator types
+ new_arg_list = []
+ for dg_type in dataGenTypesList:
+ for arg_str, arg_attrs in arg_list:
+ arg_dict = arg_attrs[0]
+ arg_dict["dg_type"] = dg_type
+
+ if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
+ # Default test
+ new_arg_list.append((arg_str, [arg_dict]))
+
+ elif dg_type == gtu.DataGenType.DOT_PRODUCT:
+ # Extra tests for each dot product test set
+ dot_products = kwargs["dot_products"]
+ if dot_products < testGen.TOSA_MI_DOT_PRODUCT_MIN:
+ print(
+ f"Skipping dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
+ )
+ continue
+ arg_dict["ks"] = kwargs["ks"]
+ for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
+ if key in kwargs:
+ arg_dict[key] = kwargs[key]
+
+ for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
+ new_arg_str = f"{arg_str}_s{s}"
+ new_arg_dict = arg_dict.copy()
+ new_arg_dict["s"] = s
+ new_arg_list.append((new_arg_str, [new_arg_dict]))
+
+ return new_arg_list
+
+ @staticmethod
def agNone(testGen, opName, shapeList, dtype, error_name=None):
"""A trivial argument generator for operators that don't take any
non-tensor arguments"""
@@ -1073,7 +1203,7 @@ class TosaArgGen:
# Shape: (OFM channels), (KD), KH, KW, IFM channels
filter_shape = shapeList[1]
- accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
# Check the rank
conv3d = opName.startswith("conv3d")
@@ -1258,12 +1388,12 @@ class TosaArgGen:
input_dtype = dtypes[0]
if error_name == ErrorIf.WrongOutputType:
- accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
+ accum_dtype = gtu.get_wrong_output_type(opName, testGen.rng, input_dtype)
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
accum_dtype = DType.INT32
else:
- accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
@@ -1285,12 +1415,28 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
# Get incorrect output dtype for ErrorIf case
- accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
+ accum_dtypes = [gtu.get_wrong_output_type(opName, testGen.rng, dtype)]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output dtype if input type is incorrect
accum_dtypes = [DType.INT32]
- return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
+ arg_list = [
+ (f"acc{testGen.typeStr(a)}", [{"acc_type": a}]) for a in accum_dtypes
+ ]
+
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ ks=int(shapeList[0][2]), # Set KS = C, from input A (N,H,C)
+ # Set dot_products = N*H*W
+ dot_products=gtu.product(
+ (shapeList[0][0], shapeList[0][1], shapeList[1][2])
+ ),
+ )
+ return arg_list
@staticmethod
def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
@@ -1303,7 +1449,7 @@ class TosaArgGen:
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
- accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
# Must be rank 4
if error_name != ErrorIf.WrongRank:
@@ -2288,9 +2434,9 @@ class TosaArgGen:
if (
output_y <= 0
- or output_y >= MAX_RESIZE_DIMENSION
+ or output_y >= gtu.MAX_RESIZE_DIMENSION
or output_x <= 0
- or output_x >= MAX_RESIZE_DIMENSION
+ or output_x >= gtu.MAX_RESIZE_DIMENSION
):
# Output dimensions out of scope
if error_name is not None and perm > 0:
@@ -2301,11 +2447,11 @@ class TosaArgGen:
if error_name == ErrorIf.ResizeOutputShapeMismatch and (
(
- output_y + scale_y_d >= MAX_RESIZE_DIMENSION
+ output_y + scale_y_d >= gtu.MAX_RESIZE_DIMENSION
and output_y - scale_y_d < 1
)
or (
- output_x + scale_x_d >= MAX_RESIZE_DIMENSION
+ output_x + scale_x_d >= gtu.MAX_RESIZE_DIMENSION
and output_x - scale_x_d < 1
)
):