diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-03-23 11:59:22 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-03-23 13:28:02 +0000 |
commit | 25669b31bae45b16d4e96ec13fa9cdeb417975f6 (patch) | |
tree | 4d27458e21d4944c45048d30bfdad6fce3027bb6 /verif/generator | |
parent | b7af461fa2e8712e762f417e9e7a3e8db9206ff7 (diff) | |
download | reference_model-25669b31bae45b16d4e96ec13fa9cdeb417975f6.tar.gz |
Improve EQUAL tests to have matching numbers
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If785fdaab3026eca5a31888115fba8a6750e0460
Diffstat (limited to 'verif/generator')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 9361ccf..83081ee 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5999,6 +5999,40 @@ class TosaTestGen: self.ser.addPlaceholder(shapeList[1], dtypeList[1], shift_arr) ) tens.extend(placeholders) + elif op["op"] == Op.EQUAL and error_name is None: + assert ( + pCount == 2 and cCount == 0 + ), "Op.EQUAL must have 2 placeholders, 0 consts" + a_arr = self.getRandTensor(shapeList[0], dtypeList[0]) + b_arr = self.getRandTensor(shapeList[1], dtypeList[1]) + # Using random numbers means that it will be very unlikely that + # there are any matching (equal) values, therefore force that + # there are twice the number of matching values as the tensor rank + for num in range(0, len(shapeList[0]) * 2): + a_index = [] + b_index = [] + # Choose an index in each axis for the whole shape + for axis in range(0, len(shapeList[0])): + # Index can be up to the largest dimension in both shapes + index = np.int32( + self.rng.integers( + 0, max(shapeList[0][axis], shapeList[1][axis]) + ) + ) + # Reduce the index down to a shape's dim for broadcasting + a_index.append(min(shapeList[0][axis] - 1, index)) + b_index.append(min(shapeList[1][axis] - 1, index)) + + a_arr[tuple(a_index)] = b_arr[tuple(b_index)] + + placeholders = [] + placeholders.append( + self.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) + ) + placeholders.append( + self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) + ) + tens.extend(placeholders) else: tens.extend( self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount]) |