aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_ternary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_ternary.cc')
-rw-r--r--reference_model/src/ops/ewise_ternary.cc30
1 files changed, 19 insertions, 11 deletions
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index fd2510f..5861cb2 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -20,9 +20,7 @@ using namespace Eigen;
using namespace tosa;
template <int Rank, TOSA_REF_TYPE Dtype>
-OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_SELECT, id_)
{
setRequiredOperands(3, 1);
@@ -83,17 +81,26 @@ int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape)
// calculates the broadcasted output shape
calculated_shape = cond_shape;
- for (size_t i = 0; i < calculated_shape.size(); i++) {
- if (calculated_shape[i] == 1) {
+ for (size_t i = 0; i < calculated_shape.size(); i++)
+ {
+ if (calculated_shape[i] == 1)
+ {
calculated_shape[i] = then_shape[i];
- } else {
- ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ else
+ {
+ ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i],
+ "Broadcast_shape failure, input shapes are not compatible");
}
- if (calculated_shape[i] == 1) {
+ if (calculated_shape[i] == 1)
+ {
calculated_shape[i] = else_shape[i];
- } else {
- ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible");
+ }
+ else
+ {
+ ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i],
+ "Broadcast_shape failure, input shapes are not compatible");
}
}
@@ -107,7 +114,8 @@ int OpSelect<Rank, Dtype>::eval()
this->broadcast(calculated_shape);
auto result_shape = this->out->getShape();
- ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match");
+ ERROR_IF(calculated_shape != result_shape,
+ "Broadcast_shape failure, calculated_shape and result_shape don't match");
this->out->getTensor() = this->cond->getTensor()
.broadcast(this->bcast_cond)