diff options
Diffstat (limited to 'reference_model/src/ops/ewise_ternary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_ternary.cc | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index fd2510f..5861cb2 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -20,9 +20,7 @@ using namespace Eigen; using namespace tosa; template <int Rank, TOSA_REF_TYPE Dtype> -OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_SELECT, id_) { setRequiredOperands(3, 1); @@ -83,17 +81,26 @@ int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape) // calculates the broadcasted output shape calculated_shape = cond_shape; - for (size_t i = 0; i < calculated_shape.size(); i++) { - if (calculated_shape[i] == 1) { + for (size_t i = 0; i < calculated_shape.size(); i++) + { + if (calculated_shape[i] == 1) + { calculated_shape[i] = then_shape[i]; - } else { - ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + else + { + ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], + "Broadcast_shape failure, input shapes are not compatible"); } - if (calculated_shape[i] == 1) { + if (calculated_shape[i] == 1) + { calculated_shape[i] = else_shape[i]; - } else { - ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + else + { + ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], + "Broadcast_shape failure, input shapes are not compatible"); } } @@ -107,7 +114,8 @@ int OpSelect<Rank, Dtype>::eval() this->broadcast(calculated_shape); auto result_shape = this->out->getShape(); - ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); + ERROR_IF(calculated_shape != result_shape, + "Broadcast_shape failure, calculated_shape and result_shape don't match"); this->out->getTensor() = this->cond->getTensor() .broadcast(this->bcast_cond) |