aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-03-23 11:59:22 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-03-23 13:28:02 +0000
commit25669b31bae45b16d4e96ec13fa9cdeb417975f6 (patch)
tree4d27458e21d4944c45048d30bfdad6fce3027bb6
parentb7af461fa2e8712e762f417e9e7a3e8db9206ff7 (diff)
downloadreference_model-25669b31bae45b16d4e96ec13fa9cdeb417975f6.tar.gz
Improve EQUAL tests to have matching numbers
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: If785fdaab3026eca5a31888115fba8a6750e0460
-rw-r--r--verif/generator/tosa_test_gen.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9361ccf..83081ee 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5999,6 +5999,40 @@ class TosaTestGen:
self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr)
)
tens.extend(placeholders)
+ elif op["op"] == Op.EQUAL and error_name is None:
+ assert (
+ pCount == 2 and cCount == 0
+ ), "Op.EQUAL must have 2 placeholders, 0 consts"
+ a_arr = self.getRandTensor(shapeList[0], dtypeList[0])
+ b_arr = self.getRandTensor(shapeList[1], dtypeList[1])
+ # Using random numbers means that it will be very unlikely that
+ # there are any matching (equal) values, therefore force that
+ # there are twice the number of matching values as the tensor rank
+ for num in range(0, len(shapeList[0]) * 2):
+ a_index = []
+ b_index = []
+ # Choose an index in each axis for the whole shape
+ for axis in range(0, len(shapeList[0])):
+ # Index can be up to the largest dimension in both shapes
+ index = np.int32(
+ self.rng.integers(
+ 0, max(shapeList[0][axis], shapeList[1][axis])
+ )
+ )
+ # Reduce the index down to a shape's dim for broadcasting
+ a_index.append(min(shapeList[0][axis] - 1, index))
+ b_index.append(min(shapeList[1][axis] - 1, index))
+
+ a_arr[tuple(a_index)] = b_arr[tuple(b_index)]
+
+ placeholders = []
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ )
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ )
+ tens.extend(placeholders)
else:
tens.extend(
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])