aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/scatter_gather.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/scatter_gather.cc')
-rw-r--r--reference_model/src/ops/scatter_gather.cc22
1 files changed, 16 insertions, 6 deletions
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index 65d61b6..bd16ad1 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -133,25 +133,25 @@ int OpScatter<Dtype>::checkTensorAttributes()
if (inputs[0]->getRank() != 3)
{
- printNodeValidationError("OpGather: values_in needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: values_in needs to be rank 3 tensor");
return 1;
}
if (inputs[1]->getRank() != 2)
{
- printNodeValidationError("OpGather: indices needs to be rank 2 tensor");
+ printNodeValidationError("OpScatter: indices needs to be rank 2 tensor");
return 1;
}
if (inputs[2]->getRank() != 3)
{
- printNodeValidationError("OpGather: input needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: input needs to be rank 3 tensor");
return 1;
}
if (outputs[0]->getRank() != 3)
{
- printNodeValidationError("OpGather: values_out needs to be rank 3 tensor");
+ printNodeValidationError("OpScatter: values_out needs to be rank 3 tensor");
return 1;
}
@@ -168,13 +168,13 @@ int OpScatter<Dtype>::checkTensorAttributes()
if (W != inputs[1]->getShape()[1])
{
- printNodeValidationError("OpGather: dimension W mismatch");
+ printNodeValidationError("OpScatter: dimension W mismatch");
return 1;
}
if (C != inputs[0]->getShape()[2] || C != inputs[2]->getShape()[2])
{
- printNodeValidationError("OpGather: dimension C mismatch");
+ printNodeValidationError("OpScatter: dimension C mismatch");
return 1;
}
@@ -201,6 +201,12 @@ int OpScatter<Dtype>::eval()
// Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
this->values_out->getTensor() = this->values_in->getTensor();
+ // Create array to check for double modification of output
+ std::array<Eigen::DenseIndex, 3> arrshape;
+ std::copy_n(outputs[0]->getShape().begin(), 3, arrshape.begin());
+ Eigen::Tensor<bool, 3> output_modified(arrshape);
+ output_modified.setZero();
+
for (int n = 0; n < N; n++)
{
for (int w = 0; w < W; w++)
@@ -209,8 +215,12 @@ int OpScatter<Dtype>::eval()
REQUIRE(k >= 0 && k < K, "OpScatter: index(%d, %d)=%d exceed valid range [0, %d]", n, w, k, K);
for (int c = 0; c < C; c++)
{
+ REQUIRE(output_modified(n, k, c) == false,
+ "OpScatter: output index(%d, %d, %d) written to more than once %d", n, w, c,
+ output_modified(n, k, c));
EigenType value = this->input->getTensor()(n, w, c);
this->values_out->getTensor()(n, k, c) = value;
+ output_modified(n, k, c) = true;
}
}
}