aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-10-06 14:37:37 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-10-06 14:40:34 -0700
commitec5586c198e81fba43f598af9ecd7a54cf460ea3 (patch)
treee9d8d76607e4db1f6b77afb1052619837ecccaa6
parent478101bebd3058a1917d9a9d87ca6d030af71c47 (diff)
downloadreference_model-ec5586c198e81fba43f598af9ecd7a54cf460ea3.tar.gz
Fix reduction ERROR_IF cases
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Id0e4ec849a9cf94c9fb04ca999738cc164dbb669
-rw-r--r--reference_model/src/ops/reduction.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 107c7a8..8c1c4d0 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -50,20 +50,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes()
if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
{
- printNodeValidationError("Reduce axis must between [0, input_rank - 1]");
+ printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
return 1;
}
- if (inputs[0]->matchRank(*outputs[0]))
+ if (inputs[0]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Input and output tensor ranks must match");
+ printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
+ return 1;
+ }
+
+ if (outputs[0]->getShape()[attribute->axis()] != 1)
+ {
+ printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
return 1;
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ASSERT_MEM(in && out);
+ if ((!in) || (!out))
+ {
+ printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
+ return 1;
+ }
dims[0] = this->attribute->axis();