From c42addcee5240d9a0846d3f7e8cb5f88c80e2975 Mon Sep 17 00:00:00 2001 From: Kevin Cheng Date: Tue, 28 Sep 2021 15:41:57 -0700 Subject: Removing rank 0 broadcast in binary op. Signed-off-by: Kevin Cheng Change-Id: I14bec5020c91f7abd6c1adc31068a22961330a97 --- reference_model/src/ops/ewise_binary.cc | 101 ++++++++------------------------ reference_model/src/ops/ewise_binary.h | 5 -- 2 files changed, 26 insertions(+), 80 deletions(-) diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index c33f646..a11d855 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -32,10 +32,8 @@ BinaryNodeBase::BinaryNodeBase(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); setRequiredRank(0, 6); - a_rank = b_rank = max_input_rank = -1; - a = b = nullptr; - a_rank0 = b_rank0 = nullptr; - result = nullptr; + a = b = nullptr; + result = nullptr; fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } @@ -55,54 +53,37 @@ int BinaryNodeBase::checkTensorAttributes() return 1; } - a_rank = inputs[0]->getRank(); - b_rank = inputs[1]->getRank(); - if (a_rank != 0 && b_rank != 0 && a_rank != b_rank) - { - printNodeValidationError("Binary operator input ranks must match"); - return 1; - } - - max_input_rank = a_rank >= b_rank ? a_rank : b_rank; - - // A & B must be the same types - if (inputs[0]->matchType(*inputs[1])) + // A & B must be the same rank and types + if (inputs[0]->matchRankType(*inputs[1])) { printNodeValidationError("Binary operator input types must match"); return 1; } - // Result's geometry must match, but the type may be wider - if (outputs[0]->getRank() != max_input_rank) - { - printNodeValidationError("Binary operator input and output genometry must match"); - return 1; - } - - if (a_rank == max_input_rank) - { - a = dynamic_cast*>(inputs[0]); - } - else - { - a_rank0 = dynamic_cast>*>(inputs[0]); - } - - if (b_rank == max_input_rank) + // Input and output rank must match + // If it's not MUL, type also needs to match as well. + if (nodeType != Op_MUL) { - b = dynamic_cast*>(inputs[1]); + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Binary operators (except MUL) input and output rank and type must match"); + return 1; + } } else { - b_rank0 = dynamic_cast>*>(inputs[1]); + if (inputs[0]->matchRank(*outputs[0])) + { + printNodeValidationError("MUL operator input and output rank must match"); + return 1; + } } + a = dynamic_cast*>(inputs[0]); + b = dynamic_cast*>(inputs[1]); result = dynamic_cast*>(outputs[0]); - // either a or b can be rank0 - // a_rank0 and b_rank0 can't be valid at the same time. - // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0' - ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result); + ASSERT_MEM(a && b && result); return 0; } @@ -114,25 +95,10 @@ int BinaryNodeBase::broadcast() std::vector a_shape, b_shape; - if (a_rank == max_input_rank) - { - a_shape = a->getShape(); - } - else - { - a_shape.assign(max_input_rank, 1); - } + a_shape = a->getShape(); + b_shape = b->getShape(); - if (b_rank == max_input_rank) - { - b_shape = b->getShape(); - } - else - { - b_shape.assign(max_input_rank, 1); - } - - for (int i = 0; i < max_input_rank; i++) + for (int i = 0; i < (int)a_shape.size(); i++) { if (a_shape[i] != output_shape[i] && a_shape[i] == 1) { @@ -164,23 +130,8 @@ int BinaryNode::eval() reshaper.fill(1); TIn ia, ib; - if (this->a_rank == this->max_input_rank) - { - ia = this->a->getTensor().broadcast(this->bcast_a); - } - else - { - ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a); - } - - if (this->b_rank == this->max_input_rank) - { - ib = this->b->getTensor().broadcast(this->bcast_b); - } - else - { - ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b); - } + ia = this->a->getTensor().broadcast(this->bcast_a); + ib = this->b->getTensor().broadcast(this->bcast_b); this->result->getTensor() = ia.binaryExpr(ib, this->fcn); @@ -475,7 +426,7 @@ int OpMul::register_fcn() } else { - result = static_cast(a) * b; + result = static_cast(a) * b; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); int64_t i32_min_in_64 = static_cast(std::numeric_limits::min()); REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range"); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 66da97a..fd4d408 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -63,12 +63,7 @@ protected: Eigen::array bcast_b; TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; - TosaReference::TensorTemplate>* a_rank0; - TosaReference::TensorTemplate>* b_rank0; TosaReference::TensorTemplate* result; - int a_rank; - int b_rank; - int max_input_rank; }; // primary class -- cgit v1.2.1