about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kevin Cheng <kevin.cheng@arm.com> 2021-09-28 15:41:57 -0700
committer Kevin Cheng <kevin.cheng@arm.com> 2021-09-28 15:41:57 -0700
commit c42addcee5240d9a0846d3f7e8cb5f88c80e2975 (patch)
tree c3fc0720663d7505fd6723c05a9df18613b47544
parent 6097c3db9a74a55d017e5168465c4e10b5793783 (diff)
download reference_model-c42addcee5240d9a0846d3f7e8cb5f88c80e2975.tar.gz
Removing rank 0 broadcast in binary op.
Signed-off-by: Kevin Cheng <kevin.cheng@arm.com>
Change-Id: I14bec5020c91f7abd6c1adc31068a22961330a97
-rw-r--r-- reference_model/src/ops/ewise_binary.cc 101
-rw-r--r-- reference_model/src/ops/ewise_binary.h 5
2 files changed, 26 insertions, 80 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index c33f646..a11d855 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -32,10 +32,8 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
setRequiredOperands(2, 1);
setRequiredRank(0, 6);
- a_rank = b_rank = max_input_rank = -1;
- a = b = nullptr;
- a_rank0 = b_rank0 = nullptr;
- result = nullptr;
+ a = b = nullptr;
+ result = nullptr;
fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
}
@@ -55,54 +53,37 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- a_rank = inputs[0]->getRank();
- b_rank = inputs[1]->getRank();
- if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
- {
- printNodeValidationError("Binary operator input ranks must match");
- return 1;
- }
-
- max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
-
- // A & B must be the same types
- if (inputs[0]->matchType(*inputs[1]))
+ // A & B must be the same rank and types
+ if (inputs[0]->matchRankType(*inputs[1]))
{
printNodeValidationError("Binary operator input types must match");
return 1;
}
- // Result's geometry must match, but the type may be wider
- if (outputs[0]->getRank() != max_input_rank)
- {
- printNodeValidationError("Binary operator input and output genometry must match");
- return 1;
- }
-
- if (a_rank == max_input_rank)
- {
- a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
- }
- else
- {
- a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
- }
-
- if (b_rank == max_input_rank)
+ // Input and output rank must match
+ // If it's not MUL, type also needs to match as well.
+ if (nodeType != Op_MUL)
{
- b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Binary operators (except MUL) input and output rank and type must match");
+ return 1;
+ }
}
else
{
- b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
+ if (inputs[0]->matchRank(*outputs[0]))
+ {
+ printNodeValidationError("MUL operator input and output rank must match");
+ return 1;
+ }
}
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- // either a or b can be rank0
- // a_rank0 and b_rank0 can't be valid at the same time.
- // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
- ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+ ASSERT_MEM(a && b && result);
return 0;
}
@@ -114,25 +95,10 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
std::vector<int> a_shape, b_shape;
- if (a_rank == max_input_rank)
- {
- a_shape = a->getShape();
- }
- else
- {
- a_shape.assign(max_input_rank, 1);
- }
+ a_shape = a->getShape();
+ b_shape = b->getShape();
- if (b_rank == max_input_rank)
- {
- b_shape = b->getShape();
- }
- else
- {
- b_shape.assign(max_input_rank, 1);
- }
-
- for (int i = 0; i < max_input_rank; i++)
+ for (int i = 0; i < (int)a_shape.size(); i++)
{
if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
{
@@ -164,23 +130,8 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval()
reshaper.fill(1);
TIn ia, ib;
- if (this->a_rank == this->max_input_rank)
- {
- ia = this->a->getTensor().broadcast(this->bcast_a);
- }
- else
- {
- ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
- }
-
- if (this->b_rank == this->max_input_rank)
- {
- ib = this->b->getTensor().broadcast(this->bcast_b);
- }
- else
- {
- ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
- }
+ ia = this->a->getTensor().broadcast(this->bcast_a);
+ ib = this->b->getTensor().broadcast(this->bcast_b);
this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
@@ -475,7 +426,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
}
else
{
- result = static_cast<int64_t>(a) * b;
+ result = static_cast<int64_t>(a) * b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 66da97a..fd4d408 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -63,12 +63,7 @@ protected:
Eigen::array<int, Rank> bcast_b;
TosaReference::TensorTemplate<TIn>* a;
TosaReference::TensorTemplate<TIn>* b;
- TosaReference::TensorTemplate<ETensor0<InEigenType>>* a_rank0;
- TosaReference::TensorTemplate<ETensor0<InEigenType>>* b_rank0;
TosaReference::TensorTemplate<TOut>* result;
- int a_rank;
- int b_rank;
- int max_input_rank;
};
// primary class