aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/reduction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r--reference_model/src/ops/reduction.cc18
1 files changed, 14 insertions, 4 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 107c7a8..8c1c4d0 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -50,20 +50,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes()
if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
{
- printNodeValidationError("Reduce axis must between [0, input_rank - 1]");
+ printNodeValidationError("ReduceOp: axis must between [0, input_rank - 1]");
return 1;
}
- if (inputs[0]->matchRank(*outputs[0]))
+ if (inputs[0]->matchRankType(*outputs[0]))
{
- printNodeValidationError("Input and output tensor ranks must match");
+ printNodeValidationError("ReduceOp: Input and output tensor ranks must match");
+ return 1;
+ }
+
+ if (outputs[0]->getShape()[attribute->axis()] != 1)
+ {
+ printNodeValidationError("ReduceOp: Output tensor shape[axis] needs to be 1.");
return 1;
}
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ASSERT_MEM(in && out);
+ if ((!in) || (!out))
+ {
+ printNodeValidationError("ReduceOp: Input or output fail to cast to Eigen tensor since rank/type not expected");
+ return 1;
+ }
dims[0] = this->attribute->axis();