aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9361ccf..83081ee 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5999,6 +5999,40 @@ class TosaTestGen:
self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
)
tens.extend(placeholders)
+ elif op["op"] == Op.EQUAL and error_name is None:
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.EQUAL must have 2 placeholders, 0 consts"
+ a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
+ b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
+ # Using random numbers means that it will be very unlikely that
+ # there are any matching (equal) values, therefore force that
+ # there are twice the number of matching values as the tensor rank
+ for num in range(0, len(shapeList[0]) * 2):
+ a_index = []
+ b_index = []
+ # Choose an index in each axis for the whole shape
+ for axis in range(0, len(shapeList[0])):
+ # Index can be up to the largest dimension in both shapes
+ index = np.int32(
+ self.rng.integers(
+ 0, max(shapeList[0][axis], shapeList[1][axis])
+ )
+ )
+ # Reduce the index down to a shape's dim for broadcasting
+ a_index.append(min(shapeList[0][axis] - 1, index))
+ b_index.append(min(shapeList[1][axis] - 1, index))
+
+ a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
+
+ placeholders = []
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
+ tens.extend(placeholders)
else:
tens.extend(
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])