aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-18 17:22:21 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-02 23:22:09 +0000
commitd1a08ce27ef8d0f6cf77e1b864610aade06edc5c (patch)
tree777992f45d240361f898b1d21902c2a46c58235f /verif
parentb0b9e33c3500bd8dc9b12ef012d4234b1245247a (diff)
downloadreference_model-d1a08ce27ef8d0f6cf77e1b864610aade06edc5c.tar.gz
Compliance mode testing for CONV2D
Added CONV2D data generation. Updated verify dot product check to latest specification. Updated test generator and python datagenerator library to create const files during test generation. Add support for compliance test sets to conformance test_select. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I5be3b761a1e3ef259c058e493877cd5a89d5778b
Diffstat (limited to 'verif')
-rw-r--r--verif/conformance/test_select.py26
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json1
-rw-r--r--verif/generator/datagenerator.py59
-rw-r--r--verif/generator/tosa_arg_gen.py108
-rw-r--r--verif/generator/tosa_test_gen.py130
-rw-r--r--verif/generator/tosa_utils.py14
-rw-r--r--verif/tests/test_tosa_datagenerator.py14
7 files changed, 247 insertions, 105 deletions
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index b7bbfc3..faefc85 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -125,6 +125,8 @@ class Operator:
# Working set of param_names - updated for negative tests
wks_param_names = None
+ COMPLIANCE_SETS = ("_s0", "_s1", "_s2", "_s3", "_s4", "_s5")
+
def __init__(
self,
test_dir: Path,
@@ -258,7 +260,15 @@ class Operator:
if (not negative and "ERRORIF" not in str(path)) or (
negative and "ERRORIF" in str(path)
):
- yield path
+ # Check for compliance test set paths
+ suffix = path.name[-3:]
+ if suffix in Operator.COMPLIANCE_SETS:
+ if suffix != Operator.COMPLIANCE_SETS[0]:
+ # Only return one of the test sets
+ continue
+ yield path.with_name(path.name[:-3])
+ else:
+ yield path
@classmethod
def get_test_paths(cls, test_dir: Path, negative):
@@ -343,7 +353,12 @@ class Operator:
for k in path_params:
unused_values[k].discard(path_params[k])
logger.debug(f"FOUND wanted: {path.name}")
- yield path
+ if path.exists():
+ yield path
+ else:
+ # Compliance test series - expand to all sets
+ for s in Operator.COMPLIANCE_SETS:
+ yield path.with_name(f"{path.name}{s}")
# search for tests that match any unused parameter values
for n, path in enumerate(sorted(list(unused_paths))):
@@ -359,7 +374,12 @@ class Operator:
unused_values[p].discard(path_params[p])
sparsity = self.sparsity[k] if k in self.sparsity else 0
logger.debug(f"FOUND unused [{k}/{n}/{sparsity}]: {path.name}")
- yield path
+ if path.exists():
+ yield path
+ else:
+ # Compliance test series - expand to all sets
+ for s in Operator.COMPLIANCE_SETS:
+ yield path.with_name(f"{path.name}{s}")
break
if not self.ignore_missing:
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 9c18879..a090479 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -598,6 +598,7 @@
"profile": [
"tosa-mi"
],
+ "support_for": [ "lazy_data_gen" ],
"generation": {
"standard": {
"negative_dim_range": "1,10",
diff --git a/verif/generator/datagenerator.py b/verif/generator/datagenerator.py
index 408c83e..0d59084 100644
--- a/verif/generator/datagenerator.py
+++ b/verif/generator/datagenerator.py
@@ -6,7 +6,7 @@ import json
from pathlib import Path
import numpy as np
-from schemavalidation import schemavalidation
+import schemavalidation.schemavalidation as sch
class GenerateError(Exception):
@@ -14,7 +14,15 @@ class GenerateError(Exception):
class GenerateLibrary:
- """Python interface to the C generate library."""
+ """Python interface to the C generate library.
+
+ Simple usage to write out all input files:
+ set_config(test_desc)
+ write_numpy_files(test_path)
+
+ To get data buffers (for const data):
+ get_tensor_data(tensor_name)
+ """
def __init__(self, generate_lib_path):
"""Find the library and set up the interface."""
@@ -22,6 +30,8 @@ class GenerateLibrary:
if not self.lib_path.is_file():
raise GenerateError(f"Could not find generate library - {self.lib_path}")
+ self.schema_validator = sch.TestDescSchemaValidator()
+
self.test_desc = None
self.json_config = None
self.lib = ct.cdll.LoadLibrary(self.lib_path)
@@ -51,8 +61,7 @@ class GenerateLibrary:
raise GenerateError("No meta/data_gen section found in desc.json")
# Validate the config versus the schema
- tdsv = schemavalidation.TestDescSchemaValidator()
- tdsv.validate_config(test_desc)
+ self.schema_validator.validate_config(test_desc)
self.test_desc = test_desc
self.json_config = test_desc["meta"]["data_gen"]
@@ -72,25 +81,25 @@ class GenerateLibrary:
return buffer, size_bytes
- def _data_gen_write(
- self, test_path: Path, json_bytes: bytes, ifm_name: str, ifm_file: str
- ):
- """Generate the named tensor data and save it in numpy format."""
+ def _data_gen_array(self, json_config: str, tensor_name: str):
+ """Generate the named tensor data and return a numpy array."""
try:
- tensor = self.json_config["tensors"][ifm_name]
+ tensor = json_config["tensors"][tensor_name]
dtype = tensor["data_type"]
shape = tuple(tensor["shape"])
except KeyError as e:
raise GenerateError(
- f"Missing data in desc.json for input {ifm_name} - {repr(e)}"
+ f"Missing data in json config for input {tensor_name} - {repr(e)}"
)
buffer, size_bytes = self._create_buffer(dtype, shape)
buffer_ptr = ct.cast(buffer, ct.c_void_p)
+ json_bytes = bytes(json.dumps(json_config), "utf8")
+
result = self.tgd_generate_data(
ct.c_char_p(json_bytes),
- ct.c_char_p(bytes(ifm_name, "utf8")),
+ ct.c_char_p(bytes(tensor_name, "utf8")),
buffer_ptr,
ct.c_size_t(size_bytes),
)
@@ -100,11 +109,19 @@ class GenerateLibrary:
arr = np.ctypeslib.as_array(buffer)
arr = np.reshape(arr, shape)
+ return arr
+
+ def _data_gen_write(
+ self, test_path: Path, json_config: str, ifm_name: str, ifm_file: str
+ ):
+ """Generate the named tensor data and save it in numpy format."""
+ arr = self._data_gen_array(json_config, ifm_name)
+
file_name = test_path / ifm_file
np.save(file_name, arr)
def write_numpy_files(self, test_path: Path):
- """Write out all the specified tensors to numpy data files."""
+ """Write out all the desc.json input tensors to numpy data files."""
if self.test_desc is None or self.json_config is None:
raise GenerateError("Cannot write numpy files as no config set up")
@@ -114,12 +131,10 @@ class GenerateLibrary:
except KeyError as e:
raise GenerateError(f"Missing data in desc.json - {repr(e)}")
- json_bytes = bytes(json.dumps(self.json_config), "utf8")
-
failures = []
for iname, ifile in zip(ifm_names, ifm_files):
try:
- self._data_gen_write(test_path, json_bytes, iname, ifile)
+ self._data_gen_write(test_path, self.json_config, iname, ifile)
except GenerateError as e:
failures.append(
f"ERROR: Failed to create data for tensor {iname} - {repr(e)}"
@@ -128,6 +143,20 @@ class GenerateLibrary:
if len(failures) > 0:
raise GenerateError("\n".join(failures))
+ def get_tensor_data(self, tensor_name: str, json_config=None):
+ """Get a numpy array for a named tensor in the data_gen meta data."""
+ if json_config is None:
+ if self.json_config is None:
+ raise GenerateError("Cannot get tensor data as no config set up")
+ json_config = self.json_config
+ else:
+ # Validate the given config
+ self.schema_validator.validate_config(
+ json_config, schema_type=sch.TD_SCHEMA_DATA_GEN
+ )
+
+ return self._data_gen_array(json_config, tensor_name)
+
def main(argv=None):
"""Simple command line interface for the data generator."""
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index f7837a0..32f4341 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -638,9 +638,9 @@ class TosaTensorValuesGen:
if (
error_name is not None
or not gtu.dtypeIsSupportedByCompliance(dtypeList[0])
- or opName in ("avg_pool2d",)
+ or "data_gen" not in testGen.TOSA_OP_LIST[opName]
):
- # Fall back to original path when dealing with unsupported types
+ # Fall back to original path when dealing with unsupported types or ops
# First turn off lazy data gen so we always produce data
lazy_data_gen = testGen.args.lazy_data_gen
@@ -660,7 +660,11 @@ class TosaTensorValuesGen:
# Create data generator meta-data
dg_type = argsDict["dg_type"]
- dg_tens_meta = {}
+ tens_data = {
+ "version": "0.1",
+ "tensors": {},
+ }
+ dg_tens_meta = tens_data["tensors"]
tens_ser_list = []
for idx, shape in enumerate(shapeList):
@@ -669,15 +673,12 @@ class TosaTensorValuesGen:
tens_meta["data_type"] = gtu.DTYPE_ATTRIBUTES[dtypeList[idx]]["json"]
tens_meta["shape"] = [int(i) for i in shape]
tens_meta["input_pos"] = idx
- tens_meta["op"] = opName.upper()
+ tens_meta["op"] = gtu.getOpNameFromOpListName(opName).upper()
if idx < pCount:
tens_meta["input_type"] = "VARIABLE"
- tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], None)
else:
tens_meta["input_type"] = "CONSTANT"
- tens = testGen.ser.addConst(shape, dtypeList[idx], None)
- tens_ser_list.append(tens)
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
info = {}
@@ -691,23 +692,55 @@ class TosaTensorValuesGen:
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
info = {}
info["s"] = argsDict["s"]
- info["ks"] = argsDict["ks"]
- for key in gtu.DG_DOT_PRODUCT_OPTIONAL_INFO:
- if key in argsDict:
- if key.endswith("_type"):
- info[key] = gtu.DTYPE_ATTRIBUTES[argsDict[key]]["json"]
- else:
- info[key] = argsDict[key]
+ info["ks"] = int(argsDict["ks"])
+ if "acc_type" in argsDict:
+ # Convert type number into JSON name
+ info["acc_type"] = gtu.DTYPE_ATTRIBUTES[argsDict["acc_type"]][
+ "json"
+ ]
+ if "kernel" in argsDict:
+ info["kernel"] = [int(k) for k in argsDict["kernel"]]
+ if "axis" in argsDict:
+ info["axis"] = int(argsDict["axis"])
tens_meta["dot_product_info"] = info
else:
# TODO - other data gen type
assert False, "TODO: support other data gen types"
+
+ # Using the finished generate config meta data - generate the data if
+ # needed and assign a tensor name from the serializer
+
+ # Need to generate data when not lazy or for the bias tensor as we need
+ # to work out if the bias data is non-zero for compliance
+ if not testGen.args.lazy_data_gen or (
+ idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT
+ ):
+ # Give this tensor a temporary name until we get one from the serializer
+ temp_name = f"placeholder_{idx}"
+ dg_tens_meta[temp_name] = tens_meta
+ # Create data now using the temporary name to access meta details
+ data = testGen.dgl.get_tensor_data(temp_name, tens_data)
+ # Remove the item as we will give it the correct name later
+ del dg_tens_meta[temp_name]
+
+ if idx == 2 and dg_type == gtu.DataGenType.DOT_PRODUCT:
+ # The KS value used by compliance verification is altered when the
+ # bias data is non-zero
+ if max(abs(data)) > 0.0:
+ argsDict["ksb"] = argsDict["ks"] + 1
+
+ if testGen.args.lazy_data_gen:
+ data = None
+
+ if tens_meta["input_type"] == "VARIABLE":
+ tens = testGen.ser.addPlaceholder(shape, dtypeList[idx], data)
+ else:
+ tens = testGen.ser.addConst(shape, dtypeList[idx], data)
+
+ tens_ser_list.append(tens)
+ # Add the meta data to the list using the serializer tensor name
dg_tens_meta[tens.name] = tens_meta
- tens_data = {
- "version": "0.1",
- "tensors": dg_tens_meta,
- }
return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
@staticmethod
@@ -1206,8 +1239,11 @@ class TosaArgGen:
accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
- # Check the rank
+ # Op type checks
conv3d = opName.startswith("conv3d")
+ depthwise = opName.startswith("depthwise")
+
+ # Check the rank
rank = 5 if conv3d else 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == rank
@@ -1215,8 +1251,12 @@ class TosaArgGen:
# kernel rank omits channels
k_rank = rank - 2
- k_pos = 0 if opName.startswith("depthwise") else 1
+ k_pos = 0 if depthwise else 1
k_shape = tuple(filter_shape[k_pos : (k_pos + k_rank)])
+ # compliance size - KS
+ k_size = gtu.product(k_shape)
+ if not depthwise:
+ k_size *= ifm_shape[-1]
if not testGen.args.level8k:
# Generate comprehensive argument lists
@@ -1363,6 +1403,24 @@ class TosaArgGen:
# Test will consume too much memory - skip it
continue
+ # Compliance - number of dot product calculations
+ if depthwise:
+ # TODO - add support
+ dots = 0
+ else:
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, filter_shape[0])
+ )
+ args_dict = {
+ "acc_type": accum_dtype,
+ "stride": s,
+ "pad": p,
+ "dilation": d,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ }
+
# Support for larger values than 9 needs different delimiter
delim = "" if max(s + p + d) <= 9 else "x"
arg_list.append(
@@ -1373,11 +1431,19 @@ class TosaArgGen:
delim.join([str(x) for x in p]),
delim.join([str(x) for x in d]),
),
- [accum_dtype, s, p, d],
+ args_dict,
)
)
n += 1
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtypes[0],
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 17cbd8f..54b624e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -56,11 +56,9 @@ class TosaTestGen:
self.random_fp_high = max(args.tensor_fp_value_range)
# JSON schema validation
self.descSchemaValidator = TestDescSchemaValidator()
- # Data generator library when not generating the data later
- if not args.lazy_data_gen:
- self.dgl = GenerateLibrary(args.generate_lib_path)
- else:
- self.dgl = None
+ # Data generator library is sometimes needed for compliance set up
+ # even if we are generating the data later (lazy_data_generation)
+ self.dgl = GenerateLibrary(args.generate_lib_path)
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -108,11 +106,6 @@ class TosaTestGen:
fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
json.dump(metaData["data_gen"], fd)
fd.write(')";\n\n')
- else:
- # Generate the data
- self.dgl.set_config(desc)
- self.dgl.write_numpy_files(path)
-
if "compliance" in metaData:
# Output datagen meta data as CPP data
path_md = path / f"{testName}_meta_compliance.cpp"
@@ -293,9 +286,15 @@ class TosaTestGen:
low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
)
- def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
- if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
- # No compliance for error tests or other data types currently
+ def tensorComplianceMetaData(
+ self, op, inputType, argsDict, outputTensor, errorName
+ ):
+ if (
+ errorName
+ or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
+ or not gtu.dtypeIsSupportedByCompliance(inputType)
+ ):
+ # No compliance for error tests or unsupported types currently
return None
# Create compliance meta data for expected output tensor
@@ -308,7 +307,9 @@ class TosaTestGen:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
"s": argsDict["s"],
- "ks": argsDict["ks"],
+ "ks": int(argsDict["ksb"])
+ if "ksb" in argsDict
+ else int(argsDict["ks"]),
}
elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
@@ -741,31 +742,30 @@ class TosaTestGen:
error_name,
qinfo,
)
- if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, inputs[0].dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 4
- result_tens = OutputShaper.conv2dOp(
+ result_tensor = OutputShaper.conv2dOp(
self.ser,
self.rng,
ifm,
@@ -784,12 +784,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -802,7 +802,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -812,7 +812,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -820,22 +820,29 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv3d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
self.ser,
@@ -960,17 +967,19 @@ class TosaTestGen:
def build_depthwise_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
result_tens = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
@@ -1121,12 +1130,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -1431,12 +1437,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -2911,7 +2914,7 @@ class TosaTestGen:
"build_fcn": (
build_conv2d,
TosaTensorGen.tgConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2931,6 +2934,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -2941,7 +2947,7 @@ class TosaTestGen:
"build_fcn": (
build_conv3d,
TosaTensorGen.tgConv3D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2972,7 +2978,7 @@ class TosaTestGen:
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 14afaa7..7fc5b52 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -51,15 +51,21 @@ class DataGenType(IntEnum):
OP_SPECIAL = 4
-# Additional (optional) data for dot product data generator
-DG_DOT_PRODUCT_OPTIONAL_INFO = ("acc_type", "kernel", "axis")
-
-
def dtypeIsSupportedByCompliance(dtype):
"""Types supported by the new data generation and compliance flow."""
+ if isinstance(dtype, list) or isinstance(dtype, tuple):
+ dtype = dtype[0]
return dtype in (DType.FP32,)
+def getOpNameFromOpListName(opName):
+ """Get the op name from a TOSA_OP_LIST name that can have suffixes."""
+ for name in ("conv2d", "depthwise_conv2d", "transpose_conv2d", "conv3d"):
+ if opName.startswith(name):
+ return name
+ return opName
+
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.
diff --git a/verif/tests/test_tosa_datagenerator.py b/verif/tests/test_tosa_datagenerator.py
index ba0235c..4f3d7fd 100644
--- a/verif/tests/test_tosa_datagenerator.py
+++ b/verif/tests/test_tosa_datagenerator.py
@@ -114,3 +114,17 @@ def test_generate_dot_product_check_fail_names():
for f in json_config["ifm_file"]:
file = TEST_DIR / f
assert not file.is_file()
+
+
+@pytest.mark.postcommit
+def test_generate_tensor_data_check():
+ glib = GenerateLibrary(GENERATE_LIB_PATH)
+ assert glib
+
+ json_config = JSON_DATAGEN_DOT_PRODUCT["meta"]["data_gen"]
+
+ for n in JSON_DATAGEN_DOT_PRODUCT["ifm_name"]:
+ arr = glib.get_tensor_data(n, json_config)
+
+ assert arr.shape == tuple(json_config["tensors"][n]["shape"])
+ assert arr.dtype == np.float32