diff options
author | Jerry Ge <jerry.ge@arm.com> | 2023-05-23 20:59:32 +0000 |
---|---|---|
committer | Dominic Symes <dominic.symes@arm.com> | 2023-06-15 18:25:54 +0000 |
commit | 135c95544fda260e8ce622cff7835b886a97663f (patch) | |
tree | 5d46f8f48978112abff037309a827b5844ee80de /reference_model/src/ops/ewise_binary.cc | |
parent | cb7201e173961760c042cade591afe763c949c8f (diff) | |
download | reference_model-135c95544fda260e8ce622cff7835b886a97663f.tar.gz |
Add ERROR_IF to incorrect broadcast shapes
Signed-off-by: Jerry Ge <jerry.ge@arm.com>
Change-Id: I7460ad9eed3ed5c7cec6e855a0303753ed28eb1c
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_binary.cc | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 1e873e7..2bc894d 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -85,25 +85,40 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> -int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() +int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast(std::vector<int>& calculated_shape) { const std::vector<int>& a_shape = a->getShape(); const std::vector<int>& b_shape = b->getShape(); const std::vector<int>& output_shape = result->getShape(); + // calculates the multipliers for Eigen for (int i = 0; i < Rank; i++) { bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } + // calculates the broadcasted output shape + calculated_shape = a_shape; + for (size_t i = 0; i < calculated_shape.size(); i++) { + if (calculated_shape[i] == 1) { + calculated_shape[i] = b_shape[i]; + } else { + ERROR_IF(b_shape[i] != 1 && b_shape[i] != calculated_shape[i], "Broadcast_shape failure, input shapes are not compatible"); + } + } + return 0; } template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNode<Rank, InDtype, OutDtype>::eval() { - this->broadcast(); + std::vector<int> calculated_shape; + this->broadcast(calculated_shape); + + auto result_shape = this->result->getShape(); + ERROR_IF(calculated_shape != result_shape, "Broadcast_shape failure, calculated_shape and result_shape don't match"); Eigen::array<int, Rank> reshaper; reshaper.fill(1); |