From 135c95544fda260e8ce622cff7835b886a97663f Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 23 May 2023 20:59:32 +0000 Subject: Add ERROR_IF to incorrect broadcast shapes Signed-off-by: Jerry Ge Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c --- reference_model/src/ops/ewise_binary.cc | 19 +++++++++++++++++-- reference_model/src/ops/ewise_binary.h | 2 +- reference_model/src/ops/ewise_ternary.cc | 26 ++++++++++++++++++++++++-- reference_model/src/ops/ewise_ternary.h | 2 +- 4 files changed, 43 insertions(+), 6 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 1e873e7..2bc894d 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -85,25 +85,40 @@ int BinaryNodeBase::checkTensorAttributes() } template -int BinaryNodeBase::broadcast() +int BinaryNodeBase::broadcast(std::vector& calculated_shape) { const std::vector& a_shape = a->getShape(); const std::vector& b_shape = b->getShape(); const std::vector& output_shape = result->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = a_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = b_shape[i]; + } else { + ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template int BinaryNode::eval() { - this->broadcast(); + std::vector calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->result->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); Eigen::array reshaper; reshaper.fill(1); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 5f6e531..3a6f24c 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -55,7 +55,7 @@ public: using TOut = Eigen::Tensor; protected: - int broadcast(); + int broadcast(std::vector& calculated_shape); protected: std::function fcn; diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 16554b5..fd2510f 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -66,13 +66,14 @@ int OpSelectBase::eval() } template -int OpSelect::broadcast() +int OpSelect::broadcast(std::vector& calculated_shape) { const std::vector& cond_shape = this->cond->getShape(); const std::vector& then_shape = this->then_val->getShape(); const std::vector& else_shape = this->else_val->getShape(); const std::vector& output_shape = this->out->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1; @@ -80,13 +81,34 @@ int OpSelect::broadcast() this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = cond_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = then_shape[i]; + } else { + ERROR_IF(then_shape[i] != 1 && then_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + + if (calculated_shape[i] == 1) { + calculated_shape[i] = else_shape[i]; + } else { + ERROR_IF(else_shape[i] != 1 && else_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template int OpSelect::eval() { - this->broadcast(); + std::vector calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->out->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); + this->out->getTensor() = this->cond->getTensor() .broadcast(this->bcast_cond) .select(this->then_val->getTensor().broadcast(this->bcast_then), diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index c6970cb..f24dfbe 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -63,7 +63,7 @@ public: : OpSelectBase(sgt_, attribute_, id_) {} virtual int eval(); - int broadcast(); + int broadcast(std::vector& calculated_shape); using InEigenType = typename OpSelectBase::InEigenType; }; -- cgit v1.2.1