diff options
author | Kevin Cheng <kevin.cheng@arm.com> | 2021-11-08 11:19:10 -0800 |
---|---|---|
committer | Kevin Cheng <kevin.cheng@arm.com> | 2021-11-09 20:59:06 +0000 |
commit | 1c3c847a4368817e2c9e3af66d5deb4c67993cbc (patch) | |
tree | 913962b18e07dc5e0d837dc182b76066538ebc65 /reference_model/src/ops/ewise_ternary.cc | |
parent | 01c359de3ad4fc617ec7726fd7103749f6d3933f (diff) | |
download | reference_model-1c3c847a4368817e2c9e3af66d5deb4c67993cbc.tar.gz |
Check valid broadcastable shape for binary and ternary ops
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I9ed3d8971a133b4cbb2cf7d827f4e69d55dee246
Diffstat (limited to 'reference_model/src/ops/ewise_ternary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_ternary.cc | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 64c4412..c265077 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -47,10 +47,11 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes() } // output and input must be the same types - if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) || - inputs[2]->matchRankType(*outputs[0])) + if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) || + inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) || + inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */)) { - printNodeValidationError("Failure to match input and output rank and type"); + printNodeValidationError("Failure to match input and output rank/type/shape"); return 1; } @@ -71,19 +72,16 @@ int OpSelectBase<Rank, Dtype>::eval() template <int Rank, DType Dtype> int OpSelect<Rank, Dtype>::broadcast() { - std::vector<int> cond_shape = this->cond->getShape(); - std::vector<int> then_shape = this->then_val->getShape(); - std::vector<int> else_shape = this->else_val->getShape(); - std::vector<int> out_shape = this->out->getShape(); + const std::vector<int>& cond_shape = this->cond->getShape(); + const std::vector<int>& then_shape = this->then_val->getShape(); + const std::vector<int>& else_shape = this->else_val->getShape(); + const std::vector<int>& output_shape = this->out->getShape(); for (int i = 0; i < Rank; i++) { - this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; - this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; - this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; - ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); - ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); - ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1; + this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1; + this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1; } return 0; |