diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-12-07 14:17:57 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-12-11 15:01:06 +0000 |
commit | 194fe314695bdfeba5b12b837b70f392db91995b (patch) | |
tree | 86c8933f3dcf05f1eff94f6edfacd7666b28192f /reference_model/src/ops | |
parent | aba79525d1348e0d964de22cef445089efaf3126 (diff) | |
download | reference_model-194fe314695bdfeba5b12b837b70f392db91995b.tar.gz |
Enforce no output rewrite REQUIRE in SCATTER
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I3555e7216d403d436bf6e39d4b16bb000645c4bb
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/scatter_gather.cc | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index 65d61b6..bd16ad1 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -133,25 +133,25 @@ int OpScatter<Dtype>::checkTensorAttributes() if (inputs[0]->getRank() != 3) { - printNodeValidationError("OpGather: values_in needs to be rank 3 tensor"); + printNodeValidationError("OpScatter: values_in needs to be rank 3 tensor"); return 1; } if (inputs[1]->getRank() != 2) { - printNodeValidationError("OpGather: indices needs to be rank 2 tensor"); + printNodeValidationError("OpScatter: indices needs to be rank 2 tensor"); return 1; } if (inputs[2]->getRank() != 3) { - printNodeValidationError("OpGather: input needs to be rank 3 tensor"); + printNodeValidationError("OpScatter: input needs to be rank 3 tensor"); return 1; } if (outputs[0]->getRank() != 3) { - printNodeValidationError("OpGather: values_out needs to be rank 3 tensor"); + printNodeValidationError("OpScatter: values_out needs to be rank 3 tensor"); return 1; } @@ -168,13 +168,13 @@ int OpScatter<Dtype>::checkTensorAttributes() if (W != inputs[1]->getShape()[1]) { - printNodeValidationError("OpGather: dimension W mismatch"); + printNodeValidationError("OpScatter: dimension W mismatch"); return 1; } if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2]) { - printNodeValidationError("OpGather: dimension C mismatch"); + printNodeValidationError("OpScatter: dimension C mismatch"); return 1; } @@ -201,6 +201,12 @@ int OpScatter<Dtype>::eval() // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. this->values_out->getTensor() = this->values_in->getTensor(); + // Create array to check for double modification of output + std::array<Eigen::DenseIndex, 3> arrshape; + std::copy_n(outputs[0]->getShape().begin(), 3, arrshape.begin()); + Eigen::Tensor<bool, 3> output_modified(arrshape); + output_modified.setZero(); + for (int n = 0; n < N; n++) { for (int w = 0; w < W; w++) @@ -209,8 +215,12 @@ int OpScatter<Dtype>::eval() REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K); for (int c = 0; c < C; c++) { + REQUIRE(output_modified(n, k, c) == false, + "OpScatter: output index(%d, %d, %d) written to more than once %d", n, w, c, + output_modified(n, k, c)); EigenType value = this->input->getTensor()(n, w, c); this->values_out->getTensor()(n, k, c) = value; + output_modified(n, k, c) = true; } } } |