diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-12-07 14:17:57 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-12-11 15:01:06 +0000 |
commit | 194fe314695bdfeba5b12b837b70f392db91995b (patch) | |
tree | 86c8933f3dcf05f1eff94f6edfacd7666b28192f /verif/generator/tosa_test_gen.py | |
parent | aba79525d1348e0d964de22cef445089efaf3126 (diff) | |
download | reference_model-194fe314695bdfeba5b12b837b70f392db91995b.tar.gz |
Enforce no output rewrite REQUIRE in SCATTER
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I3555e7216d403d436bf6e39d4b16bb000645c4bb
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index ba10dcf..53b0b75 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -1771,22 +1771,28 @@ class TosaTestGen: def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None): - # Create a new indicies tensor - # here with data that doesn't exceed the dimensions of the values_in tensor - K = values_in.shape[1] # K W = input.shape[1] # W - indicies_arr = np.int32( - self.rng.integers(low=0, high=K, size=[values_in.shape[0], W]) - ) # (N, W) - indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr) + + # Create an indices tensor here with data that doesn't exceed the + # dimension K of the values_in tensor and does NOT repeat the same K + # location as needed by the spec: + # "It is not permitted to repeat the same output index within a single + # SCATTER operation and so each output index occurs at most once." + assert K >= W + arr = [] + for n in range(values_in.shape[0]): + # Get a shuffled list of output indices and limit it to size W + arr.append(self.rng.permutation(K)[:W]) + indices_arr = np.array(arr, dtype=np.int32) # (N, W) + indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr) result_tens = OutputShaper.scatterOp( - self.ser, self.rng, values_in, indicies, input, error_name + self.ser, self.rng, values_in, indices, input, error_name ) # Invalidate Input/Output list for error if checks. - input_list = [values_in.name, indicies.name, input.name] + input_list = [values_in.name, indices.name, input.name] output_list = [result_tens.name] pCount, cCount = op["operands"] num_operands = pCount + cCount |