aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index ba10dcf..53b0b75 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1771,22 +1771,28 @@ class TosaTestGen:
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
- # Create a new indicies tensor
- # here with data that doesn't exceed the dimensions of the values_in tensor
-
K = values_in.shape[1] # K
W = input.shape[1] # W
- indicies_arr = np.int32(
- self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
- ) # (N, W)
- indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
+
+ # Create an indices tensor here with data that doesn't exceed the
+ # dimension K of the values_in tensor and does NOT repeat the same K
+ # location as needed by the spec:
+ # "It is not permitted to repeat the same output index within a single
+ # SCATTER operation and so each output index occurs at most once."
+ assert K >= W
+ arr = []
+ for n in range(values_in.shape[0]):
+ # Get a shuffled list of output indices and limit it to size W
+ arr.append(self.rng.permutation(K)[:W])
+ indices_arr = np.array(arr, dtype=np.int32) # (N, W)
+ indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)
result_tens = OutputShaper.scatterOp(
- self.ser, self.rng, values_in, indicies, input, error_name
+ self.ser, self.rng, values_in, indices, input, error_name
)
# Invalidate Input/Output list for error if checks.
- input_list = [values_in.name, indicies.name, input.name]
+ input_list = [values_in.name, indices.name, input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount