From 1c3c847a4368817e2c9e3af66d5deb4c67993cbc Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Mon, 8 Nov 2021 11:19:10 -0800 Subject: Check valid broadcastable shape for binary and ternary ops Signed-off-by: Kevin Cheng Change-Id: I9ed3d8971a133b4cbb2cf7d827f4e69d55dee246 --- reference_model/src/ops/ewise_binary.cc | 41 +++++++++++++------------------- reference_model/src/ops/ewise_ternary.cc | 24 +++++++++---------- reference_model/src/tensor.h | 22 +++++++++++++++++ 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index b199f69..287ad92 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -60,10 +60,18 @@ int BinaryNodeBase::checkTensorAttributes() return 1; } - if (inputs[0]->matchRank(*outputs[0])) + if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */)) { std::string err = - "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " input and output rank must match"; + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " lhs input and output rank/shape must match"; + printNodeValidationError(err.c_str()); + return 1; + } + + if (inputs[1]->matchRankShape(*outputs[0], true /* broadcastOk */)) + { + std::string err = + "Binary operators " + std::string(EnumNamesOp()[nodeType]) + " rhs input and output rank/shape must match"; printNodeValidationError(err.c_str()); return 1; } @@ -82,31 +90,14 @@ int BinaryNodeBase::checkTensorAttributes() template int BinaryNodeBase::broadcast() { - auto output_shape = result->getTensor().dimensions(); - - std::vector a_shape, b_shape; - - a_shape = a->getShape(); - b_shape = b->getShape(); + const std::vector& a_shape = a->getShape(); + const std::vector& b_shape = b->getShape(); + const std::vector& output_shape = result->getShape(); - for (int i = 0; i < (int)a_shape.size(); i++) + for (int i = 0; i < Rank; i++) { - if (a_shape[i] != output_shape[i] && a_shape[i] == 1) - { - bcast_a[i] = output_shape[i]; - } - else - { - bcast_a[i] = 1; - } - if (b_shape[i] != output_shape[i] && b_shape[i] == 1) - { - bcast_b[i] = output_shape[i]; - } - else - { - bcast_b[i] = 1; - } + bcast_a[i] = (a_shape[i] != output_shape[i] && a_shape[i] == 1) ? output_shape[i] : 1; + bcast_b[i] = (b_shape[i] != output_shape[i] && b_shape[i] == 1) ? output_shape[i] : 1; } return 0; diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 64c4412..c265077 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -47,10 +47,11 @@ int OpSelectBase::checkTensorAttributes() } // output and input must be the same types - if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) || - inputs[2]->matchRankType(*outputs[0])) + if (inputs[0]->matchRankShape(*outputs[0], true /* broadcastOk */) || + inputs[1]->matchRankTypeShape(*outputs[0], true /* broadcastOk */) || + inputs[2]->matchRankTypeShape(*outputs[0], true /* broadcastOk */)) { - printNodeValidationError("Failure to match input and output rank and type"); + printNodeValidationError("Failure to match input and output rank/type/shape"); return 1; } @@ -71,19 +72,16 @@ int OpSelectBase::eval() template int OpSelect::broadcast() { - std::vector cond_shape = this->cond->getShape(); - std::vector then_shape = this->then_val->getShape(); - std::vector else_shape = this->else_val->getShape(); - std::vector out_shape = this->out->getShape(); + const std::vector& cond_shape = this->cond->getShape(); + const std::vector& then_shape = this->then_val->getShape(); + const std::vector& else_shape = this->else_val->getShape(); + const std::vector& output_shape = this->out->getShape(); for (int i = 0; i < Rank; i++) { - this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; - this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1; - this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; - ERROR_IF((this->bcast_cond[i] * cond_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); - ERROR_IF((this->bcast_then[i] * then_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); - ERROR_IF((this->bcast_else[i] * else_shape[i]) != out_shape[i], "SELECT broadcast invariant failed"); + this->bcast_cond[i] = (cond_shape[i] != output_shape[i] && cond_shape[i] == 1) ? output_shape[i] : 1; + this->bcast_then[i] = (then_shape[i] != output_shape[i] && then_shape[i] == 1) ? output_shape[i] : 1; + this->bcast_else[i] = (else_shape[i] != output_shape[i] && else_shape[i] == 1) ? output_shape[i] : 1; } return 0; diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 3fa23f9..5536583 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -148,6 +148,28 @@ public: return 0; } + const int matchRankShape(const Tensor& ref, const bool broadcastOk = false) const + { + if (matchRank(ref)) + return 1; + + for (size_t i = 0; i < shape.size(); i++) + { + if (shape[i] != ref.shape[i]) + { + if (!broadcastOk || + // For broadcasts, at least one operand must have size 1 + // if they don't both match + (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1))) + { + return 1; + } + } + } + + return 0; + } + // Sometimes we might want to match several semi-compatible types, // so just check rank and size here const int matchRankSize(const Tensor& ref) const -- cgit v1.2.1