aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2024-02-12 16:53:23 -0800
committerEric Kunze <eric.kunze@arm.com>2024-02-22 02:08:36 +0000
commitc7bfa58c76e73aac772f714d8ae04cc875715689 (patch)
tree0491f8466ad6463ec03cbb2c80ccaa416d940b4a
parent2c34b4616a10539211e7006bc43f3c71e86c30bb (diff)
downloadreference_model-c7bfa58c76e73aac772f714d8ae04cc875715689.tar.gz
Change the shift of mul to tensor type
Right shift result on i32_t data type only, i.e. other data types don't carry the shift operand. In the spec, the shift type is a tensor in MT profile and is an attribute in BI/MI profiles. Currently we treat the shift as tensor throughout. In implementation, since `ternaryExpr` is not implemented in Eigen, decompose the original calculation into multiply and shift operation separately, and execute them via `binaryExpr`. Change-Id: I349f4969545134ac5f13bc83032cd75cca3e7ba0 Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
-rw-r--r--reference_model/src/ops/ewise_binary.cc125
-rw-r--r--reference_model/src/ops/ewise_binary.h27
2 files changed, 113 insertions, 39 deletions
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index b513f9a..ed176f3 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -441,9 +441,100 @@ int OpMinimum<Rank, Dtype>::register_fcn()
}
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::eval()
+{
+ // All cases except in_out_t == int32_t go to the general binary op workflow.
+ if constexpr (InDtype != TOSA_REF_TYPE_INT32)
+ {
+ return BinaryNode<Rank, InDtype, OutDtype>::eval();
+ }
+ else
+ {
+ std::vector<int> calculated_shape;
+ this->broadcast(calculated_shape);
+
+ auto result_shape = this->result->getShape();
+ ERROR_IF(calculated_shape != result_shape,
+ "Broadcast_shape failure, calculated_shape and result_shape don't match");
+
+ TIn ia = this->a->getTensor().broadcast(this->bcast_a);
+ TIn ib = this->b->getTensor().broadcast(this->bcast_b);
+
+ using TInt64 = Eigen::Tensor<int64_t, Rank>;
+ TInt64 tmp_result = ia.binaryExpr(ib, this->mul_fcn);
+
+ // Retrieve `shift` value and construct an Eigen tensor instance for it.
+ s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+ ASSERT_MEM(s);
+
+ int shift = s->getTensor()(0);
+ TIn is(ia);
+ is.setConstant(shift);
+
+ TOut result = tmp_result.binaryExpr(is, this->shr_fcn);
+ this->result->getTensor() = result;
+
+ return GraphNode::eval();
+ }
+}
+
+// Eigen operators require that tensor operands meet NumDims > 0; partially specialize
+// this like we did for the base class.
+template <>
+int OpMul<0, TOSA_REF_TYPE_INT32, TOSA_REF_TYPE_INT32>::eval()
+{
+ Eigen::Tensor<int64_t, 0> tmp_result = this->a->getTensor().binaryExpr(this->b->getTensor(), this->mul_fcn);
+
+ // Retrieve `shift` value.
+ s = dynamic_cast<TosaReference::TensorTemplate<TShift>*>(this->inputs[2]);
+ ASSERT_MEM(s);
+
+ Eigen::Tensor<int64_t, 0> shift;
+ shift.setConstant(s->getTensor()(0));
+
+ this->result->getTensor() = tmp_result.binaryExpr(shift, this->shr_fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
- int32_t shift = attribute->shift();
+ // Register evaluation function for in_out_t == int32_t case first as it supports shift
+ // right to int32_t output.
+ if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+ {
+ // Perform multiplication on int32_t inputs to produce an int64_t result.
+ this->mul_fcn = [](InEigenType a, InEigenType b) -> int64_t {
+ int64_t result = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+ return result;
+ };
+
+ // Convert data from int64_t to int32_t.
+ this->shr_fcn = [this](int64_t a, InEigenType shift) -> OutEigenType {
+ int64_t result;
+ if (shift > 0)
+ {
+ int64_t round = INT64_C(1) << (shift - 1);
+ result = a + round;
+ result = result >> shift;
+
+ REQUIRE(result >= QMin && result <= QMax,
+ "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin, QMax);
+ }
+ else
+ {
+ result = a;
+ int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
+ int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
+ REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
+ return static_cast<InEigenType>(result);
+ }
+ return static_cast<OutEigenType>(result);
+ };
+
+ return 0;
+ }
switch (InDtype)
{
@@ -455,31 +546,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
break;
- case TOSA_REF_TYPE_INT32:
- this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
- int64_t result;
- if (shift > 0)
- {
- int64_t round = INT64_C(1) << (shift - 1);
- result = static_cast<int64_t>(a) * static_cast<int64_t>(b) + round;
- result = result >> shift;
-
- REQUIRE(result >= QMin && result <= QMax,
- "OpMul: result %" PRId64 " exceeds valid range [%" PRId64 ", %" PRId64 "]", result, QMin,
- QMax);
- }
- else
- {
- result = static_cast<int64_t>(a) * b;
- int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
- int64_t i32_min_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::min());
- REQUIRE(result <= i32_max_in_64 && result >= i32_min_in_64, "OpMul: result not in i32 range");
- return static_cast<InEigenType>(result);
- }
-
- return static_cast<OutEigenType>(result);
- };
- break;
case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16:
this->fcn = [](InEigenType lhs, InEigenType rhs) -> OutEigenType {
@@ -497,13 +563,6 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
return 0;
}
-template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
-OpMul<Rank, InDtype, OutDtype>::~OpMul()
-{
- if (attribute)
- delete attribute;
-}
-
template <int Rank, TOSA_REF_TYPE Dtype>
int OpPow<Rank, Dtype>::register_fcn()
{
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 1215c93..8d2e486 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -159,18 +159,33 @@ public:
OpMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: BinaryNode<Rank, InDtype, OutDtype>(sgt_, Op_MUL, id_)
{
- INIT_ATTRIBUTE(Mul);
+ if constexpr (InDtype == TOSA_REF_TYPE_INT32)
+ {
+ // Require `shift` operand.
+ this->setRequiredOperands(3, 1);
+ }
register_fcn();
}
- virtual ~OpMul();
static constexpr int64_t QMin = GetQMin<OutDtype>::value;
static constexpr int64_t QMax = GetQMax<OutDtype>::value;
- using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<OutDtype>::type;
- virtual int register_fcn();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using ShiftEigenType = typename GetEigenType<TOSA_REF_TYPE_INT8>::type;
+
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ using TShift = Eigen::Tensor<ShiftEigenType, 0>;
+
+ int register_fcn();
+ int eval();
+
+ // Note that INT64 is not natively supported in Dtype system.
+ std::function<int64_t(InEigenType, InEigenType)> mul_fcn;
+ std::function<OutEigenType(int64_t, InEigenType)> shr_fcn;
protected:
- TosaMulAttribute* attribute;
+ TosaReference::TensorTemplate<TShift>* s;
};
template <int Rank, TOSA_REF_TYPE InDtype>