aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py15
1 files changed, 9 insertions, 6 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 1f54851..454013a 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -767,8 +767,10 @@ class TosaTensorValuesGen:
return TosaTensorValuesGen.TVGInfo(tens_ser_list, tens_data)
@staticmethod
- def tvgNegate(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
+ def tvgNegate(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
if dtypeList[0] == DType.INT32 and error_name is None:
+ # Integer test
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
assert (
pCount == 1 and cCount == 0
@@ -780,14 +782,15 @@ class TosaTensorValuesGen:
arr = np.int32(
testGen.rng.integers(low=min_val, high=(max_val + 1), size=shapeList[0])
)
- placeholders = []
- placeholders.append(
+ tens_ser_list = []
+ tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], arr)
)
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
else:
- return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, error_name
+ # ERROR_IF or floating point test
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
# Set the data range to half the largest value