diff options
Diffstat (limited to 'reference_model/src/ops/ewise_ternary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_ternary.cc | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 16554b5..fd2510f 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -66,13 +66,14 @@ int OpSelectBase<Rank, Dtype>::eval() } template <int Rank, TOSA_REF_TYPE Dtype> -int OpSelect<Rank, Dtype>::broadcast() +int OpSelect<Rank, Dtype>::broadcast(std::vector<int>& calculated_shape) { const std::vector<int>& cond_shape = this->cond->getShape(); const std::vector<int>& then_shape = this->then_val->getShape(); const std::vector<int>& else_shape = this->else_val->getShape(); const std::vector<int>& output_shape = this->out->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1; @@ -80,13 +81,34 @@ int OpSelect<Rank, Dtype>::broadcast() this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = cond_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = then_shape[i]; + } else { + ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + + if (calculated_shape[i] == 1) { + calculated_shape[i] = else_shape[i]; + } else { + ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template <int Rank, TOSA_REF_TYPE Dtype> int OpSelect<Rank, Dtype>::eval() { - this->broadcast(); + std::vector<int> calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->out->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); + this->out->getTensor() = this->cond->getTensor() .broadcast(this->bcast_cond) .select(this->then_val->getTensor().broadcast(this->bcast_then), |