aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-12-07 14:17:57 +0000
committerEric Kunze <eric.kunze@arm.com>2023-12-11 15:01:06 +0000
commit194fe314695bdfeba5b12b837b70f392db91995b (patch)
tree86c8933f3dcf05f1eff94f6edfacd7666b28192f
parentaba79525d1348e0d964de22cef445089efaf3126 (diff)
downloadreference_model-194fe314695bdfeba5b12b837b70f392db91995b.tar.gz
Enforce no output rewrite REQUIRE in SCATTER
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I3555e7216d403d436bf6e39d4b16bb000645c4bb
-rw-r--r--reference_model/src/ops/scatter_gather.cc22
-rw-r--r--verif/generator/tosa_arg_gen.py12
-rw-r--r--verif/generator/tosa_test_gen.py24
3 files changed, 37 insertions, 21 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index 65d61b6..bd16ad1 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -133,25 +133,25 @@ int OpScatter<Dtype>::checkTensorAttributes()
if (inputs[0]->getRank() != 3)
{
- printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: values_in needs to be rank 3 tensor");
return 1;
}
if (inputs[1]->getRank() != 2)
{
- printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ printNodeValidationError("OpScatter: indices needs to be rank 2 tensor");
return 1;
}
if (inputs[2]->getRank() != 3)
{
- printNodeValidationError("OpGather: input needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: input needs to be rank 3 tensor");
return 1;
}
if (outputs[0]->getRank() != 3)
{
- printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: values_out needs to be rank 3 tensor");
return 1;
}
@@ -168,13 +168,13 @@ int OpScatter<Dtype>::checkTensorAttributes()
if (W != inputs[1]->getShape()[1])
{
- printNodeValidationError("OpGather: dimension W mismatch");
+ printNodeValidationError("OpScatter: dimension W mismatch");
return 1;
}
if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
{
- printNodeValidationError("OpGather: dimension C mismatch");
+ printNodeValidationError("OpScatter: dimension C mismatch");
return 1;
}
@@ -201,6 +201,12 @@ int OpScatter<Dtype>::eval()
// Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
this->values_out->getTensor() = this->values_in->getTensor();
+ // Create array to check for double modification of output
+ std::array<Eigen::DenseIndex, 3> arrshape;
+ std::copy_n(outputs[0]->getShape().begin(), 3, arrshape.begin());
+ Eigen::Tensor<bool, 3> output_modified(arrshape);
+ output_modified.setZero();
+
for (int n = 0; n < N; n++)
{
for (int w = 0; w < W; w++)
@@ -209,8 +215,12 @@ int OpScatter<Dtype>::eval()
REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
for (int c = 0; c < C; c++)
{
+ REQUIRE(output_modified(n, k, c) == false,
+ "OpScatter: output index(%d, %d, %d) written to more than once %d", n, w, c,
+ output_modified(n, k, c));
EigenType value = this->input->getTensor()(n, w, c);
this->values_out->getTensor()(n, k, c) = value;
+ output_modified(n, k, c) = true;
}
}
}
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 193da73..35253e0 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -213,17 +213,17 @@ class TosaTensorGen:
assert rank == 3
values_in_shape = testGen.makeShape(rank)
+ K = values_in_shape[1]
# ignore max batch size if target shape is set
if testGen.args.max_batch_size and not testGen.args.target_shapes:
values_in_shape[0] = min(values_in_shape[0], testGen.args.max_batch_size)
- W = testGen.randInt(
- testGen.args.tensor_shape_range[0], testGen.args.tensor_shape_range[1]
- )
- # Constrict W if one dimension is too large to keep tensor size reasonable
- if max(values_in_shape) > 5000:
- W = testGen.randInt(0, 16)
+ # Make sure W is not greater than K, as we can only write each output index
+ # once (having a W greater than K means that you have to repeat a K index)
+ W_min = min(testGen.args.tensor_shape_range[0], K)
+ W_max = min(testGen.args.tensor_shape_range[1], K)
+ W = testGen.randInt(W_min, W_max) if W_min < W_max else W_min
input_shape = [values_in_shape[0], W, values_in_shape[2]]
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index ba10dcf..53b0b75 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1771,22 +1771,28 @@ class TosaTestGen:
def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
- # Create a new indicies tensor
- # here with data that doesn't exceed the dimensions of the values_in tensor
-
K = values_in.shape[1] # K
W = input.shape[1] # W
- indicies_arr = np.int32(
- self.rng.integers(low=0, high=K, size=[values_in.shape[0], W])
- ) # (N, W)
- indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
+
+ # Create an indices tensor here with data that doesn't exceed the
+ # dimension K of the values_in tensor and does NOT repeat the same K
+ # location as needed by the spec:
+ # "It is not permitted to repeat the same output index within a single
+ # SCATTER operation and so each output index occurs at most once."
+ assert K >= W
+ arr = []
+ for n in range(values_in.shape[0]):
+ # Get a shuffled list of output indices and limit it to size W
+ arr.append(self.rng.permutation(K)[:W])
+ indices_arr = np.array(arr, dtype=np.int32) # (N, W)
+ indices = self.ser.addConst(indices_arr.shape, DType.INT32, indices_arr)
result_tens = OutputShaper.scatterOp(
- self.ser, self.rng, values_in, indicies, input, error_name
+ self.ser, self.rng, values_in, indices, input, error_name
)
# Invalidate Input/Output list for error if checks.
- input_list = [values_in.name, indicies.name, input.name]
+ input_list = [values_in.name, indices.name, input.name]
output_list = [result_tens.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount