aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-08 11:45:44 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-02-12 12:17:53 +0000
commit587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (patch)
tree2465605214e9723636bad2645508300421967052
parent01e1c1c7f965ceb07e78a3b1ad063161c0f47b94 (diff)
downloadreference_model-587cc84c2b8c4b0d030b5e257c9a32461c0969b9.tar.gz
Update test builder internal interfaces
Move remaining ops from using testArgs to argsDict. All tvg/build_fcn function interfaces updated. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ie886fd931bd74608bda621363100bf8bfd7385e6
-rw-r--r--verif/generator/tosa_arg_gen.py117
-rw-r--r--verif/generator/tosa_test_gen.py237
-rw-r--r--verif/generator/tosa_verif_build_tests.py4
3 files changed, 220 insertions, 138 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 33e74b5..7ec0cfe 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -636,18 +636,6 @@ class TosaTensorValuesGen:
self.tensorList = tensorList
self.dataGenDict = dataGenDict
- @staticmethod
- def tvgDefault(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
- pCount, cCount = op["operands"]
-
- tens = []
- tens.extend(
- testGen.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
- )
- tens.extend(testGen.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
-
- return tens
-
# Default high value for random numbers
TVG_FLOAT_HIGH_VALUE = {
DType.FP32: (1 << 128) - (1 << (127 - 23)),
@@ -969,7 +957,7 @@ class TosaTensorValuesGen:
@staticmethod
def tvgCondIfWhileLoop(
- testGen, op, dtypeList, shapeList, testArgs, error_name=None
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
):
if dtypeList[0] in (
DType.INT32,
@@ -979,9 +967,10 @@ class TosaTensorValuesGen:
# Limit input tensors with cond_if_binary or while_loop to stop
# saturation of add/sub ops with int32 and keep all logical shift
# values between 0 to 31 for int16 or int8
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
pRemain = pCount
- placeholders = []
+ tens_ser_list = []
for idx, shape in enumerate(shapeList[:]):
if dtypeList[0] == DType.INT32:
arr = testGen.getRandTensor(shapeList[idx], DType.INT16)
@@ -990,32 +979,33 @@ class TosaTensorValuesGen:
testGen.rng.integers(low=0, high=32, size=shapeList[idx])
)
if pRemain > 0:
- placeholders.append(
+ tens_ser_list.append(
testGen.ser.addPlaceholder(shape, dtypeList[idx], arr)
)
pRemain -= 1
else:
- placeholders.append(
+ tens_ser_list.append(
testGen.ser.addConst(shape, dtypeList[idx], arr)
)
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
else:
- return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, error_name
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
def tvgArithmeticRightShift(
- testGen, op, dtypeList, shapeList, testArgs, error_name=None
+ testGen, opName, dtypeList, shapeList, argsDict, error_name=None
):
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
# Force value of operand[1] to be within [0, num_bits]
assert (
pCount == 2 and cCount == 0
), "Op.ArithmeticRightShift must have 2 placeholders, 0 consts"
- placeholders = []
+ tens_ser_list = []
for idx, shape in enumerate(shapeList[:]):
if idx == 1:
if dtypeList[idx] == DType.INT8:
@@ -1030,23 +1020,23 @@ class TosaTensorValuesGen:
raise Exception("OpArithmeticRightShift: invalid input dtype")
else:
arr = testGen.getRandTensor(shape, dtypeList[idx])
- placeholders.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
+ tens_ser_list.append(testGen.ser.addPlaceholder(shape, dtypeList[idx], arr))
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
@staticmethod
- def tvgReshape(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgReshape(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["new_shape"])]
# Create a new list for the pre-generated data in argsDict["fixed_data"]
argsDict["fixed_data"] = [None, argsDict["new_shape"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgPad(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgPad(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
# argsDict["pad"] is 2D array, need to flatten it to get list of values
pad_values = argsDict["pad"].flatten()
dtypeList[1] = DType.SHAPE
@@ -1055,11 +1045,11 @@ class TosaTensorValuesGen:
argsDict["fixed_data"] = [None, pad_values]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgSlice(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgSlice(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["start"])]
dtypeList[2] = DType.SHAPE
@@ -1068,17 +1058,17 @@ class TosaTensorValuesGen:
argsDict["fixed_data"] = [None, argsDict["start"], argsDict["size"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
- def tvgTile(testGen, op, dtypeList, shapeList, argsDict, error_name=None):
+ def tvgTile(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
dtypeList[1] = DType.SHAPE
shapeList[1] = [len(argsDict["multiples"])]
argsDict["fixed_data"] = [None, argsDict["multiples"]]
return TosaTensorValuesGen.tvgLazyGenDefault(
- testGen, op, dtypeList, shapeList, argsDict, error_name
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod
@@ -2776,10 +2766,23 @@ class TosaArgGen:
int(double_round),
int(per_channel),
),
- [outDtype, scale32, double_round, per_channel],
+ {
+ "output_dtype": outDtype,
+ "scale": scale32,
+ "double_round": double_round,
+ "per_channel": per_channel,
+ },
)
)
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ inDtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
@@ -2808,9 +2811,20 @@ class TosaArgGen:
def agArithmeticRightShift(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
- arg_list.append(("roundTrue", [True]))
- arg_list.append(("roundFalse", [False]))
+ for round in (True, False):
+ args_dict = {
+ "round": round,
+ }
+ arg_list.append((f"round{round}", args_dict))
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod
@@ -3414,9 +3428,18 @@ class TosaArgGen:
arg_list.append(
(
"",
- [table],
+ {"table": table},
)
)
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
def agCondIf(testGen, opName, shapeList, dtype, error_name=None):
@@ -3426,15 +3449,33 @@ class TosaArgGen:
arg_list = []
for c in [False, True]:
- arg_list.append(("cond{}".format(int(c)), [c]))
+ arg_list.append(("cond{}".format(int(c)), {"condition": c}))
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
def agWhileLoop(testGen, opName, shapeList, dtype, error_name=None):
# While loop: 0 iterations, 1, more than 1
arg_list = []
- for iter in [0, 1, 4]:
- arg_list.append(("iter{}".format(iter), [iter]))
+ for iterations in [0, 1, 4]:
+ arg_list.append(("iter{}".format(iterations), {"iterations": iterations}))
+ # Now add data generator types
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtype,
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 2d471c0..4ead982 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -519,15 +519,18 @@ class TosaTestGen:
return result_tens
def build_arithmetic_right_shift(
- self, op, a, b, round, validator_fcns=None, error_name=None
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- result_tens = OutputShaper.binaryBroadcastOp(
+ assert len(inputs) == 2
+ a, b = inputs
+ round = args_dict["round"]
+ result_tensor = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, b.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -542,8 +545,8 @@ class TosaTestGen:
input1=a,
input2=b,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -554,7 +557,12 @@ class TosaTestGen:
attr.ArithmeticRightShiftAttribute(round)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_mul(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -612,15 +620,20 @@ class TosaTestGen:
return TosaTestGen.BuildInfo(result_tensor, compliance)
- def build_table(self, op, a, table, validator_fcns=None, error_name=None):
- result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
+ def build_table(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 1
+ a = inputs[0]
+ table = args_dict["table"]
+ result_tensor = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
attr.TableAttribute(table)
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -634,8 +647,8 @@ class TosaTestGen:
op=op,
input_shape=a.shape,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
- result_tensors=[result_tens],
+ output_dtype=result_tensor.dtype,
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -644,7 +657,11 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_select(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
@@ -2075,15 +2092,20 @@ class TosaTestGen:
def build_rescale(
self,
op,
- val,
- out_dtype,
- scale32,
- double_round,
- per_channel,
- validator_fcns,
- error_name,
+ inputs,
+ args_dict,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
- result_tens = OutputShaper.typeConversionOp(
+ assert len(inputs) == 1
+ val = inputs[0]
+ out_dtype = args_dict["output_dtype"]
+ scale32 = args_dict["scale"]
+ double_round = args_dict["double_round"]
+ per_channel = args_dict["per_channel"]
+
+ result_tensor = OutputShaper.typeConversionOp(
self.ser, self.rng, val, out_dtype, error_name
)
@@ -2203,7 +2225,7 @@ class TosaTestGen:
# Invalidate Input/Output list for error if checks.
input_list = [val.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -2224,7 +2246,7 @@ class TosaTestGen:
double_round=double_round,
input_list=input_list,
output_list=output_list,
- result_tensors=[result_tens],
+ result_tensors=[result_tensor],
num_operands=num_operands,
):
return None
@@ -2243,7 +2265,12 @@ class TosaTestGen:
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, val.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def _get_condition_tensor(self, op, cond, error_name):
if error_name == ErrorIf.CondIfCondNotMatchingBool:
@@ -2263,11 +2290,21 @@ class TosaTestGen:
return cond_tens
def build_cond_if_const(
- self, op, then_tens, else_tens, cond, validator_fcns=None, error_name=None
+ self,
+ op,
+ inputs,
+ args_dict,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
# For cond_if with constants, we're supplied with then/else tensors that we ignore
# (except for the generated shape) and the condition. Build Then/Else blocks
# and fill them with const nodes for the body.
+ assert len(inputs) == 2
+ then_tens, else_tens = inputs
+
+ cond = args_dict["condition"]
# Condition tensor
cond_tens = self._get_condition_tensor(op, cond, error_name)
@@ -2275,6 +2312,8 @@ class TosaTestGen:
# Make then/else tensors
out_shape = then_tens.shape
+ dtype = DType.INT32
+
# Create an incorrect output shape for error_if tests
if error_name in [
ErrorIf.CondIfOutputListThenGraphMismatch,
@@ -2293,7 +2332,7 @@ class TosaTestGen:
else_arr = np.int32(self.rng.integers(0, 256, size=out_shape))
# And the result tensor based on any of the outputs
- result_tens = self.ser.addOutput(out_shape, DType.INT32)
+ result_tensor = self.ser.addOutput(out_shape, dtype)
# Create the attribute with the names of the then/else blocks
then_block = "THEN_BLOCK"
@@ -2302,21 +2341,21 @@ class TosaTestGen:
attr.CondIfAttribute(then_block, else_block)
# Finally, build the op and the two blocks
- self.ser.addOperator(op["op"], [cond_tens.name], [result_tens.name], attr)
+ self.ser.addOperator(op["op"], [cond_tens.name], [result_tensor.name], attr)
self.ser.addBasicBlock(then_block)
# Build the actual then/else tensors inside their blocks
if error_name == ErrorIf.CondIfOutputListThenGraphMismatch:
- then_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+ then_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
else:
- then_tens = self.ser.addConst(out_shape, DType.INT32, then_arr)
+ then_tens = self.ser.addConst(out_shape, dtype, then_arr)
self.ser.addOutputTensor(then_tens)
self.ser.addBasicBlock(else_block)
if error_name == ErrorIf.CondIfOutputListElseGraphMismatch:
- else_tens = self.ser.addConst(incorrect_shape, DType.INT32, incorrect_arr)
+ else_tens = self.ser.addConst(incorrect_shape, dtype, incorrect_arr)
else:
- else_tens = self.ser.addConst(out_shape, DType.INT32, else_arr)
+ else_tens = self.ser.addConst(out_shape, dtype, else_arr)
self.ser.addOutputTensor(else_tens)
if not TosaErrorValidator.evValidateErrorIfs(
@@ -2329,18 +2368,32 @@ class TosaTestGen:
):
return None
- return result_tens
+ compliance = self.tensorComplianceMetaData(
+ op, dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_cond_if_binary(
- self, op, a, b, cond, validator_fcns=None, error_name=None
+ self,
+ op,
+ inputs,
+ args_dict,
+ validator_fcns=None,
+ error_name=None,
+ qinfo=None,
):
# For cond_if with a binary op in the then/else blocks, take a and b and
# alternately add or subtract them based on the condition
+ assert len(inputs) == 2
+ a, b = inputs
+
+ cond = args_dict["condition"]
# Condition tensor
cond_tens = self._get_condition_tensor(op, cond, error_name)
- result_tens = self.ser.addOutput(a.shape, a.dtype)
+ result_tensor = self.ser.addOutput(a.shape, a.dtype)
# Create the attribute with the names of the then/else blocks
then_block = "THEN_BLOCK"
@@ -2362,17 +2415,24 @@ class TosaTestGen:
# Finally, build the op and the two blocks
self.ser.addOperator(
- op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
+ op["op"], [cond_tens.name, a.name, b.name], [result_tensor.name], attr
)
if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
- then_op, else_op = Op.ADD, Op.SUB
+ then_op, else_op = self.TOSA_OP_LIST["add"], self.TOSA_OP_LIST["sub"]
elif a.dtype in (DType.INT8, DType.INT16):
- then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
+ then_op, else_op = (
+ self.TOSA_OP_LIST["logical_right_shift"],
+ self.TOSA_OP_LIST["logical_left_shift"],
+ )
else:
assert False, f"No tests for DType: {a.dtype}"
- for block, op in ((then_block, then_op), (else_block, else_op)):
+ # Determine the element-wise binary operation that compliance will need to
+ # check the results of
+ compliance_op = then_op if cond else else_op
+
+ for block, block_op in ((then_block, then_op), (else_block, else_op)):
self.ser.addBasicBlock(block)
if (
error_name == ErrorIf.CondIfInputListThenGraphMismatch
@@ -2398,7 +2458,7 @@ class TosaTestGen:
self.ser.addInputTensor(a)
self.ser.addInputTensor(b)
tens = self.ser.addOutput(a.shape, a.dtype)
- self.ser.addOperator(op, [a.name, b.name], [tens.name])
+ self.ser.addOperator(block_op["op"], [a.name, b.name], [tens.name])
if not TosaErrorValidator.evValidateErrorIfs(
self.ser,
@@ -2412,9 +2472,19 @@ class TosaTestGen:
):
return None
- return result_tens
+ compliance = self.tensorComplianceMetaData(
+ compliance_op, a.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
+
+ def build_while_loop(
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ ):
+ assert len(inputs) == 1
+ a = inputs[0]
+ iter_val = args_dict["iterations"]
- def build_while_loop(self, op, a, iter_val, validator_fcns=None, error_name=None):
iter = self.ser.addPlaceholder([], DType.INT32, [np.int32(iter_val)])
cond_block = "COND_BLOCK"
@@ -2533,7 +2603,11 @@ class TosaTestGen:
):
return None
- return acc_out
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, acc_out, error_name
+ )
+
+ return TosaTestGen.BuildInfo(acc_out, compliance)
def build_fft2d(
self,
@@ -2891,7 +2965,7 @@ class TosaTestGen:
return testList
def serializeTest(
- self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs
+ self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, argsDict
):
try:
op = self.TOSA_OP_LIST[opName]
@@ -2947,60 +3021,27 @@ class TosaTestGen:
# Extra meta data for the desc.json
tensMeta = {}
- # Check we are using the new testArgs interface with an argsDict dictionary
- if isinstance(testArgs, dict):
- # New interface with args info in dictionary
- argsDict = testArgs
- assert "dg_type" in argsDict
- tvgInfo = tvgen_fcn(
- self, opName, dtypeList, shapeList, argsDict, error_name
- )
- if tvgInfo.dataGenDict:
- tensMeta["data_gen"] = tvgInfo.dataGenDict
- tens = tvgInfo.tensorList
-
- result = build_fcn(
- self,
- op,
- tens,
- argsDict,
- validator_fcns=error_if_validators,
- error_name=error_name,
- qinfo=qinfo,
- )
- else:
- # Old interface with args info in a list
- tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
-
- try:
- if error_if_validators is None:
- if qinfo is not None:
- result = build_fcn(self, op, *tens, *testArgs, qinfo)
- else:
- result = build_fcn(self, op, *tens, *testArgs)
- else:
- if qinfo is not None:
- result = build_fcn(
- self,
- op,
- *tens,
- *testArgs,
- validator_fcns=error_if_validators,
- error_name=error_name,
- qinfo=qinfo,
- )
- else:
- result = build_fcn(
- self,
- op,
- *tens,
- *testArgs,
- validator_fcns=error_if_validators,
- error_name=error_name,
- )
- except TypeError as e:
- print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
- raise e
+ # Check we are using the new interface with an argsDict dictionary
+ assert isinstance(
+ argsDict, dict
+ ), f"{opName} is not using new tvg/build_fcn interface"
+
+ # New interface with args info in dictionary
+ assert "dg_type" in argsDict
+ tvgInfo = tvgen_fcn(self, opName, dtypeList, shapeList, argsDict, error_name)
+ if tvgInfo.dataGenDict:
+ tensMeta["data_gen"] = tvgInfo.dataGenDict
+ tens = tvgInfo.tensorList
+
+ result = build_fcn(
+ self,
+ op,
+ tens,
+ argsDict,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ qinfo=qinfo,
+ )
if result:
# The test is valid, serialize it
@@ -3874,7 +3915,7 @@ class TosaTestGen:
"build_fcn": (
build_table,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTable,
),
"types": [DType.INT8, DType.INT16],
@@ -4686,7 +4727,7 @@ class TosaTestGen:
"build_fcn": (
build_rescale,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agRescale,
),
"types": [
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index d01e8a7..8012d93 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -321,7 +321,7 @@ def main(argv=None):
testStrings = []
try:
- for opName, testStr, dtype, error, shapeList, testArgs in testList:
+ for opName, testStr, dtype, error, shapeList, argsDict in testList:
# Check for and skip duplicate tests
if testStr in testStrings:
print(f"Skipping duplicate test: {testStr}")
@@ -331,7 +331,7 @@ def main(argv=None):
results.append(
ttg.serializeTest(
- opName, testStr, dtype, error, shapeList, testArgs
+ opName, testStr, dtype, error, shapeList, argsDict
)
)
except Exception as e: