aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-20 16:15:30 +0000
committerEric Kunze <eric.kunze@arm.com>2023-11-30 18:52:24 +0000
commit3047625f7d4b3a77cb3a3480481122f7ba01be2d (patch)
tree125ce52f1b9f65090a0bdb1c2fafeb8e0c516425
parent35a3aa994cf18f735193a05a7eb2c61d497233d2 (diff)
downloadreference_model-3047625f7d4b3a77cb3a3480481122f7ba01be2d.tar.gz
Adjust random data ranges for Main Compliance to avoid FP inf and nan
POW - there are now 3 test sets to cover random ranges. Also added ROUND mode to data generator to force integer exponent values. LOG, EXP, RSQRT, REDUCE_SUM & FULLY_CONNECTED - have had their ranges reduced for each test. Fix generate library configuration defaults and checks. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ie5d3bd78f690cc787a2ca4eb9b4bd6808bd9238c
-rw-r--r--reference_model/src/generate/generate_dot_product.cc7
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc5
-rw-r--r--reference_model/src/generate/generate_utils.cc16
-rw-r--r--reference_model/src/generate/generate_utils.h1
-rw-r--r--scripts/schemavalidation/datagen-config.schema.json4
-rw-r--r--verif/conformance/test_select.py45
-rw-r--r--verif/generator/tosa_arg_gen.py302
-rw-r--r--verif/generator/tosa_test_gen.py14
8 files changed, 355 insertions, 39 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index 4054472..c8a2b13 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -387,7 +387,12 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
if (!generator)
{
WARNING("[Generator][DP] Requested generator could not be created!");
- return 0;
+ return false;
+ }
+ if (cfg.dotProductInfo.ks <= 0)
+ {
+ WARNING("[Generator][DP] Invalid test set kernel size %d.", cfg.dotProductInfo.ks);
+ return false;
}
// Select which generator to use
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index d8d2288..b51424d 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -93,6 +93,7 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
PseudoRandomGeneratorFloat<float>* generator;
+ bool roundMode = prinfo.round;
if (prinfo.range.size() == 2)
{
@@ -117,6 +118,10 @@ bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t s
// Set every 4th value to 0 to enable better comparison testing
a[t] = 0.f;
}
+ else if (roundMode)
+ {
+ a[t] = std::roundf(a[t]);
+ }
}
return true;
}
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 1edc79d..58a3d33 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -116,6 +116,10 @@ void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
{
j.at("range").get_to(pseudoRandomInfo.range);
}
+ if (j.contains("round"))
+ {
+ j.at("round").get_to(pseudoRandomInfo.round);
+ }
}
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
@@ -126,10 +130,22 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
j.at("input_pos").get_to(cfg.inputPos);
j.at("op").get_to(cfg.opType);
j.at("generator").get_to(cfg.generatorType);
+
+ // Set up defaults for dotProductInfo
+ cfg.dotProductInfo.s = -1;
+ cfg.dotProductInfo.ks = -1;
+ cfg.dotProductInfo.accType = DType_UNKNOWN;
+ cfg.dotProductInfo.kernel = std::vector<int32_t>();
+ cfg.dotProductInfo.axis = -1;
if (j.contains("dot_product_info"))
{
j.at("dot_product_info").get_to(cfg.dotProductInfo);
}
+
+ // Set up defaults for pseudoRandomInfo
+ cfg.pseudoRandomInfo.rngSeed = -1;
+ cfg.pseudoRandomInfo.range = std::vector<std::string>();
+ cfg.pseudoRandomInfo.round = false;
if (j.contains("pseudo_random_info"))
{
j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 8d0f654..f9ec713 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -62,6 +62,7 @@ struct PseudoRandomInfo
int64_t rngSeed;
std::vector<std::string> range;
+ bool round;
};
/// \brief Generator configuration
diff --git a/scripts/schemavalidation/datagen-config.schema.json b/scripts/schemavalidation/datagen-config.schema.json
index 68789f6..08d564b 100644
--- a/scripts/schemavalidation/datagen-config.schema.json
+++ b/scripts/schemavalidation/datagen-config.schema.json
@@ -69,6 +69,10 @@
"description": "[low value, high value] as strings to allow ints and floats",
"type": "string"
}
+ },
+ "round": {
+ "type": "boolean",
+ "description": "force rounding of all values"
}
},
"additionalProperties": false,
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index faefc85..cebdf62 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -125,8 +125,6 @@ class Operator:
# Working set of param_names - updated for negative tests
wks_param_names = None
- COMPLIANCE_SETS = ("_s0", "_s1", "_s2", "_s3", "_s4", "_s5")
-
def __init__(
self,
test_dir: Path,
@@ -260,13 +258,13 @@ class Operator:
if (not negative and "ERRORIF" not in str(path)) or (
negative and "ERRORIF" in str(path)
):
- # Check for compliance test set paths
- suffix = path.name[-3:]
- if suffix in Operator.COMPLIANCE_SETS:
- if suffix != Operator.COMPLIANCE_SETS[0]:
- # Only return one of the test sets
- continue
- yield path.with_name(path.name[:-3])
+ # Check for test set paths
+ match = re.match(r"(.*)_s([0-9]+)", path.name)
+ if match:
+ if match.group(2) == "0":
+ # Only return the truncated test name
+ # of the first test of a set
+ yield path.with_name(match.group(1))
else:
yield path
@@ -298,6 +296,23 @@ class Operator:
params[param] = sorted(list(params[param]))
return params
+ @staticmethod
+ def _get_test_set_paths(path):
+ """Expand a path to find all the test sets."""
+ s = 0
+ paths = []
+ # Have a bound for the maximum test sets
+ while s < 100:
+ set_path = path.with_name(f"{path.name}_s{s}")
+ if set_path.exists():
+ paths.append(set_path)
+ else:
+ if s == 0:
+ logger.error(f"Could not find test set 0 - {str(set_path)}")
+ break
+ s += 1
+ return paths
+
def select_tests(self): # noqa: C901 (function too complex)
"""Generate the paths to the selected tests for this operator."""
if not self.test_paths:
@@ -356,9 +371,9 @@ class Operator:
if path.exists():
yield path
else:
- # Compliance test series - expand to all sets
- for s in Operator.COMPLIANCE_SETS:
- yield path.with_name(f"{path.name}{s}")
+ # Must be a test set - expand to all test sets
+ for p in Operator._get_test_set_paths(path):
+ yield p
# search for tests that match any unused parameter values
for n, path in enumerate(sorted(list(unused_paths))):
@@ -377,9 +392,9 @@ class Operator:
if path.exists():
yield path
else:
- # Compliance test series - expand to all sets
- for s in Operator.COMPLIANCE_SETS:
- yield path.with_name(f"{path.name}{s}")
+ # Must be a test set - expand to all test sets
+ for p in Operator._get_test_set_paths(path):
+ yield p
break
if not self.ignore_missing:
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 9147605..3057963 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -635,17 +635,40 @@ class TosaTensorValuesGen:
DType.BF16: (1 << 128) - (1 << (127 - 7)),
}
+ # Default lowest normal values for random numbers
+ TVG_FLOAT_LOW_VALUE = {
+ DType.FP32: np.exp2(-126),
+ DType.FP16: np.exp2(-14),
+ DType.BF16: np.exp2(-126),
+ }
+
@staticmethod
- def _get_data_range(testGen, dtype, highValueLookup):
+ def _get_data_range(testGen, dtype, highValueLookup, lowValueLookup=None):
+ # Return a tuple of (low,high) data range values for the given data
+ # type using a combination of per operator table limits, data limits
+ # and user supplied ranges for FP numbers
if dtype in highValueLookup:
- data_range = testGen.getDTypeRange(dtype, high_inclusive=True)
+ type_range = testGen.getDTypeRange(dtype, high_inclusive=True)
high_val = highValueLookup[dtype]
+ if lowValueLookup is not None and dtype in lowValueLookup:
+ low_val = lowValueLookup[dtype]
+ else:
+ low_val = -high_val
# Set the values to something that won't produce infinity whilst
- # respecting the default ranges if less than the high value
- return [
- max(-high_val, data_range[0]),
- min(high_val, data_range[1]),
- ]
+ # respecting the default ranges if more/less than the low/high
+ # values
+ data_range = (
+ max(low_val, type_range[0]),
+ min(high_val, type_range[1]),
+ )
+ if data_range[0] > data_range[1]:
+ # Invalid data range from low to high created due to user
+ # constraints revert to using internal ranges as they are
+ # known to work
+ msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
+ warnings.warn(msg)
+ data_range = (low_val, high_val)
+ return data_range
return None
@staticmethod
@@ -664,9 +687,18 @@ class TosaTensorValuesGen:
# Fall back to internal data gen when dealing with unsupported types or ops
data_range = argsDict["data_range"] if "data_range" in argsDict else None
for idx, info in enumerate(zip(shapeList, dtypeList)):
+ roundMode = False
shape, dtype = info
+ if "data_range_list" in argsDict:
+ data_range = argsDict["data_range_list"][idx]["range"]
+ roundMode = (
+ "round" in argsDict["data_range_list"][idx]
+ and argsDict["data_range_list"][idx]["round"] is True
+ )
# Ignore lazy data gen option and create data array using any range limits
arr = testGen.getRandTensor(shape, dtype, data_range)
+ if roundMode:
+ arr = np.round(arr)
if idx < pCount:
tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtype, arr))
else:
@@ -699,7 +731,12 @@ class TosaTensorValuesGen:
info = {}
# TODO - generate seed for this generator based on test
info["rng_seed"] = 42
- if "data_range" in argsDict:
+
+ if "data_range_list" in argsDict:
+ data_range = argsDict["data_range_list"][idx]["range"]
+ if "round" in argsDict["data_range_list"][idx]:
+ info["round"] = argsDict["data_range_list"][idx]["round"]
+ elif "data_range" in argsDict:
data_range = argsDict["data_range"]
else:
data_range = testGen.getDTypeRange(
@@ -788,7 +825,7 @@ class TosaTensorValuesGen:
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
- # Set the data range to half the largest value
+ # Set the ADD/SUB data range to half the largest value to avoid infinities
TVG_FLOAT_HIGH_VALUE_ADDSUB = {
DType.FP32: (TVG_FLOAT_HIGH_VALUE[DType.FP32] / 2),
DType.FP16: (TVG_FLOAT_HIGH_VALUE[DType.FP16] / 2),
@@ -987,7 +1024,8 @@ class TosaTensorValuesGen:
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
- # Set the data range to the square root of the largest value
+ # Set the MUL data range to the square root of the largest value
+ # to avoid infinities
TVG_FLOAT_HIGH_VALUE_MUL = {
DType.FP32: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
DType.FP16: math.sqrt(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
@@ -1167,7 +1205,8 @@ class TosaTensorValuesGen:
@staticmethod
def tvgReduceSum(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
- if dtypeList[0] == DType.INT32:
+ dtype = dtypeList[0]
+ if dtype == DType.INT32:
op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
assert (
@@ -1181,14 +1220,219 @@ class TosaTensorValuesGen:
)
tens_ser_list = []
tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
+ testGen.ser.addPlaceholder(shapeList[0], dtype, values_arr)
)
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
else:
# ERROR_IF or dot product floating point test
+ if (
+ error_name is None
+ and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
+ ):
+ # Limit ranges for (non error & non compliance) tests by using
+ # values that can be summed on any axis to not hit infinity
+ highval_lookup = {
+ dtype: TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype]
+ / max(shapeList[0])
+ }
+ data_range = TosaTensorValuesGen._get_data_range(
+ testGen, dtype, highval_lookup
+ )
+ assert data_range is not None
+ argsDict["data_range"] = data_range
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ # Set the POW exponent high data range
+ TVG_FLOAT_HIGH_VALUE_POW_EXP = {
+ DType.FP32: 10.0,
+ DType.FP16: 10.0,
+ DType.BF16: 10.0,
+ }
+ # POW highest base value (within a safe margin of error) that can be raised
+ # to +ve exponent that doesn't become Infinity
+ TVG_FLOAT_HIGH_VALUE_POW_BASE = {
+ DType.FP32: math.floor(
+ math.pow(
+ TVG_FLOAT_HIGH_VALUE[DType.FP32],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
+ )
+ ),
+ DType.FP16: math.floor(
+ math.pow(
+ TVG_FLOAT_HIGH_VALUE[DType.FP16],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
+ )
+ ),
+ DType.BF16: math.floor(
+ math.pow(
+ TVG_FLOAT_HIGH_VALUE[DType.BF16],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
+ )
+ ),
+ }
+ # POW lowest base value (within a safe margin of error) that can be raised
+ # to -ve exponent that doesn't become Infinity
+ TVG_FLOAT_LOW_VALUE_POW_BASE = {
+ DType.FP32: math.ceil(
+ math.pow(
+ 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP32],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP32],
+ )
+ * 1000
+ )
+ / 1000,
+ DType.FP16: math.ceil(
+ math.pow(
+ 1.0 / TVG_FLOAT_HIGH_VALUE[DType.FP16],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.FP16],
+ )
+ * 1000
+ )
+ / 1000,
+ DType.BF16: math.ceil(
+ math.pow(
+ 1.0 / TVG_FLOAT_HIGH_VALUE[DType.BF16],
+ 1.0 / TVG_FLOAT_HIGH_VALUE_POW_EXP[DType.BF16],
+ )
+ * 1000
+ )
+ / 1000,
+ }
+
+ @staticmethod
+ def tvgPow(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ if error_name is not None:
return TosaTensorValuesGen.tvgLazyGenDefault(
testGen, opName, dtypeList, shapeList, argsDict, error_name
)
+ dtype = dtypeList[0]
+ # Different ranges for POW
+ test_set = argsDict["s"]
+ if test_set == 0:
+ # Positive base with fractional exponent
+ base_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ dtype,
+ TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
+ TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
+ )
+ exp_range = TosaTensorValuesGen._get_data_range(
+ testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
+ )
+ exp_round = False
+ else:
+ # Integer exponent
+ exp_range = TosaTensorValuesGen._get_data_range(
+ testGen, dtype, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_EXP
+ )
+ exp_round = True
+ if test_set == 1:
+ # Positive base
+ base_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ dtype,
+ TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE,
+ TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE,
+ )
+ else:
+ assert test_set == 2
+ # Negative base
+ # Supply new look up tables with negative values
+ base_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ dtype,
+ {dtype: -TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_POW_BASE[dtype]},
+ {dtype: -TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_POW_BASE[dtype]},
+ )
+
+ data_range_list = (
+ {
+ "range": base_range,
+ },
+ {
+ "range": exp_range,
+ "round": exp_round,
+ },
+ )
+ argsDict["data_range_list"] = data_range_list
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ @staticmethod
+ def tvgLogRsqrt(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ # LOG & RSQRT data range from lowest expressible positive number to
+ # largest to avoid NaNs
+ data_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ dtypeList[0],
+ TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE,
+ TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE,
+ )
+ if data_range:
+ argsDict["data_range"] = data_range
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ # Set the EXP data range to the log of the largest to smallest values
+ # to avoid infinities or making the result zero
+ TVG_FLOAT_HIGH_VALUE_EXP = {
+ DType.FP32: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP32]),
+ DType.FP16: math.log(TVG_FLOAT_HIGH_VALUE[DType.FP16]),
+ DType.BF16: math.log(TVG_FLOAT_HIGH_VALUE[DType.BF16]),
+ }
+ TVG_FLOAT_LOW_VALUE_EXP = {
+ DType.FP32: math.log(TVG_FLOAT_LOW_VALUE[DType.FP32]),
+ DType.FP16: math.log(TVG_FLOAT_LOW_VALUE[DType.FP16]),
+ DType.BF16: math.log(TVG_FLOAT_LOW_VALUE[DType.BF16]),
+ }
+
+ @staticmethod
+ def tvgExp(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ data_range = TosaTensorValuesGen._get_data_range(
+ testGen,
+ dtypeList[0],
+ TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE_EXP,
+ TosaTensorValuesGen.TVG_FLOAT_LOW_VALUE_EXP,
+ )
+ if data_range:
+ argsDict["data_range"] = data_range
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
+
+ @staticmethod
+ def tvgFullyConnected(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
+ ):
+ dtype = dtypeList[0]
+ if (
+ error_name is None
+ and argsDict["dg_type"] != gtu.ComplianceMode.DOT_PRODUCT
+ and dtype in (DType.FP16, DType.BF16)
+ ):
+ # TODO - Remove once FP16 and BF16 enabled for DOT_PRODUCT compliance
+ # Limit ranges for (non error & non compliance) FP tests by using
+ # values that can be multiplied on any axis to not hit infinity/NaN
+ IC = shapeList[0][1]
+ highval_lookup = {
+ dtype: math.pow(TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], 1 / IC)
+ }
+ data_range = TosaTensorValuesGen._get_data_range(
+ testGen, dtype, highval_lookup
+ )
+ assert data_range is not None
+ argsDict["data_range"] = data_range
+
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
+ )
class TosaArgGen:
@@ -1225,8 +1469,14 @@ class TosaArgGen:
for arg_str, args_dict in arg_list:
args_dict["dg_type"] = dg_type
if dg_type == gtu.DataGenType.PSEUDO_RANDOM:
- # Default test
- new_arg_list.append((arg_str, args_dict))
+ if error_name is None:
+ num_test_sets = (
+ args_dict["num_test_sets"]
+ if "num_test_sets" in args_dict
+ else 0
+ )
+ else:
+ num_test_sets = 0
elif dg_type == gtu.DataGenType.DOT_PRODUCT:
# Extra tests for each dot product test set
@@ -1245,11 +1495,17 @@ class TosaArgGen:
assert "ks" in args_dict
assert "acc_type" in args_dict
- for s in testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS:
- new_arg_str = f"{arg_str}_s{s}"
+ num_test_sets = testGen.TOSA_MI_DOT_PRODUCT_TEST_SETS
+
+ if num_test_sets > 0:
+ for s in range(0, num_test_sets):
+ new_arg_str = f"{arg_str}_s{s}" if arg_str else f"s{s}"
new_args_dict = args_dict.copy()
new_args_dict["s"] = s
new_arg_list.append((new_arg_str, new_args_dict))
+ else:
+ # Default is a single test
+ new_arg_list.append((arg_str, args_dict))
return new_arg_list
@@ -1268,6 +1524,20 @@ class TosaArgGen:
return arg_list
@staticmethod
+ def agPow(testGen, opName, shapeList, dtype, error_name=None):
+ """Pow operator needs different test sets to cover random numbers
+ without creating NaNs or Infs"""
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ [("", {"num_test_sets": 3})],
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
+ return arg_list
+
+ @staticmethod
def agAxis(testGen, opName, shapeList, dtype, error_name=None):
"""Build the axis argument for operators that take a single axis"""
arg_list = []
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7b44ced..63958a9 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -37,7 +37,7 @@ class TosaTestGen:
TOSA_8K_LEVEL_MAX_STRIDE = 8192
# Main compliance dot product statistical test range
- TOSA_MI_DOT_PRODUCT_TEST_SETS = range(0, 6)
+ TOSA_MI_DOT_PRODUCT_TEST_SETS = 6
TOSA_MI_DOT_PRODUCT_MIN = 1000
def __init__(self, args):
@@ -3074,7 +3074,7 @@ class TosaTestGen:
"build_fcn": (
build_fully_connected,
TosaTensorGen.tgFullyConnected,
- TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaTensorValuesGen.tvgFullyConnected,
TosaArgGen.agFullyConnected,
),
"qgen": TosaQuantGen.qgConv,
@@ -3562,8 +3562,8 @@ class TosaTestGen:
"build_fcn": (
build_binary_broadcast,
TosaTensorGen.tgBroadcastFuzz,
- TosaTensorValuesGen.tvgLazyGenDefault,
- TosaArgGen.agNone,
+ TosaTensorValuesGen.tvgPow,
+ TosaArgGen.agPow,
),
"types": TYPE_FP,
"error_if_validators": (
@@ -3705,7 +3705,7 @@ class TosaTestGen:
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaTensorValuesGen.tvgExp,
TosaArgGen.agNone,
),
"types": TYPE_FP,
@@ -3746,7 +3746,7 @@ class TosaTestGen:
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaTensorValuesGen.tvgLogRsqrt,
TosaArgGen.agNone,
),
"types": TYPE_FP,
@@ -3828,7 +3828,7 @@ class TosaTestGen:
"build_fcn": (
build_unary,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaTensorValuesGen.tvgLogRsqrt,
TosaArgGen.agNone,
),
"types": TYPE_FP,