aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/ewise_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/ewise_binary.cc')
-rw-r--r--reference_model/src/ops/ewise_binary.cc125
1 files changed, 92 insertions, 33 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b513f9a..ed176f3 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -441,9 +441,100 @@ int OpMinimum<Rank, Dtype>::register_fcn()
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::eval()
+{
+ // All cases except in_out_t == int32_t go to the general binary op workflow.
+ if constexpr (InDtype != TOSA_REF_TYPE_INT32)
+ {
+ return BinaryNode<Rank, InDtype, OutDtype>::eval();
+ }
+ else
+ {
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->result->getShape();
+ ERROR_IF(calculated_shape != result_shape,
+ "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
+ TIn ia = this->a->getTensor().broadcast(this->bcast_a);
+ TIn ib = this->b->getTensor().broadcast(this->bcast_b);
+
+ using TInt64 = Eigen::Tensor<int64_t, Rank>;
+ TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
+
+ // Retrieve `shift` value and construct an Eigen tensor instance for it.
+ s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+ ASSERT_MEM(s);
+
+ int shift = s->getTensor()(0);
+ TIn is(ia);
+ is.setConstant(shift);
+
+ TOut result = tmp_result.binaryExpr(is, this->shr_fcn);
+ this->result->getTensor() = result;
+
+ return GraphNode::eval();
+ }
+}
+
+// Eigen operators require tensor operands to meet NumDims > 0; partially
+// specialize this like we did for the base class.
+template <>
+int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
+{
+ Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
+
+ // Retrieve `shift` value.
+ s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+ ASSERT_MEM(s);
+
+ Eigen::Tensor<int64_t, 0> shift;
+ shift.setConstant(s->getTensor()(0));
+
+ this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
- int32_t shift = attribute->shift();
+ // Register evaluation function for in_out_t == int32_t case first as it supports shift
+ // right to int32_t output.
+ if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+ {
+ // Perform multiplication on int32_t inputs to produce an int64_t result.
+ this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t {
+ int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+ return result;
+ };
+
+ // Convert data from int64_t to int32_t.
+ this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType {
+ int64_t result;
+ if (shift > 0)
+ {
+ int64_t round = INT64_C(1) << (shift - 1);
+ result = a + round;
+ result = result >> shift;
+
+ REQUIRE(result >= QMin && result <= QMax,
+ "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax);
+ }
+ else
+ {
+ result = a;
+ int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
+ int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
+ REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
+ return static_cast<InEigenType>(result);
+ }
+ return static_cast<OutEigenType>(result);
+ };
+
+ return 0;
+ }
switch (InDtype)
{
@@ -455,31 +546,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
break;
- case TOSA_REF_TYPE_INT32:
- this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
- int64_t result;
- if (shift > 0)
- {
- int64_t round = INT64_C(1) << (shift - 1);
- result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
- result = result >> shift;
-
- REQUIRE(result >= QMin && result <= QMax,
- "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin,
- QMax);
- }
- else
- {
- result = static_cast<int64_t>(a) * b;
- int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
- int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
- REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
- return static_cast<InEigenType>(result);
- }
-
- return static_cast<OutEigenType>(result);
- };
- break;
case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16:
this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType {
@@ -497,13 +563,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
return 0;
}
-template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-OpMul<Rank, InDtype, OutDtype>::~OpMul()
-{
- if (attribute)
- delete attribute;
-}
-
template <int Rank, TOSA_REF_TYPE Dtype>
int OpPow<Rank, Dtype>::register_fcn()
{