aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_ternary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_ternary.cc')
-rw-r--r--reference_model/src/ops/ewise_ternary.cc24
1 files changed, 11 insertions, 13 deletions
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 64c4412..c265077 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -47,10 +47,11 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
}
// output and input must be the same types
- if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
- inputs[2]->matchRankType(*outputs[0]))
+ if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) ||
+ inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) ||
+ inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */))
{
- printNodeValidationError("Failure to match input and output rank and type");
+ printNodeValidationError("Failure to match input and output rank/type/shape");
return 1;
}
@@ -71,19 +72,16 @@ int OpSelectBase<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpSelect<Rank, Dtype>::broadcast()
{
- std::vector<int> cond_shape = this->cond->getShape();
- std::vector<int> then_shape = this->then_val->getShape();
- std::vector<int> else_shape = this->else_val->getShape();
- std::vector<int> out_shape = this->out->getShape();
+ const std::vector<int>& cond_shape = this->cond->getShape();
+ const std::vector<int>& then_shape = this->then_val->getShape();
+ const std::vector<int>& else_shape = this->else_val->getShape();
+ const std::vector<int>& output_shape = this->out->getShape();
for (int i = 0; i < Rank; i++)
{
- this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
- this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
- this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
- ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
- ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
- ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed");
+ this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1;
+ this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1;
+ this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1;
}
return 0;