about summary refs log tree commit diff
path: root/reference_model/src/ops/reduction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- reference_model/src/ops/reduction.cc 111
1 file changed, 94 insertions(+), 17 deletions(-)
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index cd9d55f..bf8ba57 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -30,14 +30,14 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::~ReduceNode()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int ReduceNode<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -100,7 +100,7 @@ struct AnyReducer {
bool finalize(const bool accum) const { return accum; }
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
@@ -108,7 +108,7 @@ int OpReduceAll<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
@@ -116,7 +116,7 @@ int OpReduceAny<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -124,7 +124,7 @@ int OpReduceMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -132,35 +132,74 @@ int OpReduceMin<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
+struct ProductDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum *= val;
+ }
+ double initialize() const
+ {
+ return 1.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceProductDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, ProductDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
@@ -183,7 +222,7 @@ struct SumRequiresReducer {
SubgraphTraverser* parent_sgt;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
@@ -191,6 +230,40 @@ int OpReduceSumInt<Rank, Dtype>::eval()
return GraphNode::eval();
}
+struct SumDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum += val;
+ }
+ double initialize() const
+ {
+ return 0.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceSumDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, SumDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
@@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
@@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);