aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-11-15 15:52:06 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2023-11-23 14:09:14 +0000
commita015001dfbd0ed48caf54fd66b0509ee344a229e (patch)
tree5f99a7d2d4aba2db2e672efb1168db961f99a544 /verif/generator/tosa_arg_gen.py
parent0bbd8bcfb20ec834f18d0bb89fc69ba4e92b3019 (diff)
downloadreference_model-a015001dfbd0ed48caf54fd66b0509ee344a229e.tar.gz
Main Compliance testing support for COMPARISON ops
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Id6229cfaccad866b110630119eb045dbf6453bf5
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 6675025..9147605 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1119,14 +1119,18 @@ class TosaTensorValuesGen:
return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
@staticmethod
- def tvgEqual(testGen, op, dtypeList, shapeList, testArgs, error_name=None):
- if error_name is None:
+ def tvgEqual(testGen, opName, dtypeList, shapeList, argsDict, error_name=None):
+ if error_name is None and not gtu.dtypeIsSupportedByCompliance(dtypeList[0]):
+ # Integer
+ op = testGen.TOSA_OP_LIST[opName]
pCount, cCount = op["operands"]
assert (
pCount == 2 and cCount == 0
), "Op.EQUAL must have 2 placeholders, 0 consts"
+
a_arr = testGen.getRandTensor(shapeList[0], dtypeList[0])
b_arr = testGen.getRandTensor(shapeList[1], dtypeList[1])
+
# Using random numbers means that it will be very unlikely that
# there are any matching (equal) values, therefore force that
# there are twice the number of matching values as the tensor rank
@@ -1147,17 +1151,18 @@ class TosaTensorValuesGen:
a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
- placeholders = []
- placeholders.append(
+ tens_ser_list = []
+ tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
)
- placeholders.append(
+ tens_ser_list.append(
testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
- return placeholders
+ return TosaTensorValuesGen.TVGInfo(tens_ser_list, None)
else:
- return TosaTensorValuesGen.tvgDefault(
- testGen, op, dtypeList, shapeList, testArgs, error_name
+ # ERROR_IF or floating point test
+ return TosaTensorValuesGen.tvgLazyGenDefault(
+ testGen, opName, dtypeList, shapeList, argsDict, error_name
)
@staticmethod