aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/reduction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r--reference_model/src/ops/reduction.cc20
1 files changed, 15 insertions, 5 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index fd48472..f07ffd7 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -192,7 +192,10 @@ int OpReduceSum<Rank, Dtype>::eval()
{
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
- this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+ this->out->getTensor() = this->in->getTensor()
+ .sum(this->dims)
+ .reshape(this->out->getTensor().dimensions())
+ .unaryExpr([](float f) { return fpTrunc<Dtype>(f); });
break;
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_INT32:
@@ -225,7 +228,9 @@ struct SumRequiresReducer {
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, SumRequiresReducer(this->parent_sgt))
+ .reshape(this->out->getTensor().dimensions());
return GraphNode::eval();
}
@@ -250,12 +255,17 @@ struct SumDoubleReducer
template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumDouble<Rank, Dtype>::eval()
{
+ typename ReduceNode<Rank, Dtype>::TIn in_val = this->in->getTensor();
+ if (g_func_config.abs_mode)
+ {
+ // in abs_mode: take abs values of in value
+ in_val = in_val.abs();
+ }
switch (Dtype)
{
case TOSA_REF_TYPE_FP64:
- this->out->getTensor() = this->in->getTensor()
- .reduce(this->dims, SumDoubleReducer())
- .reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() =
+ in_val.reduce(this->dims, SumDoubleReducer()).reshape(this->out->getTensor().dimensions());
break;
default:
ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));