diff options
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- | reference_model/src/ops/reduction.cc | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index fd48472..f07ffd7 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -192,7 +192,10 @@ int OpReduceSum<Rank, Dtype>::eval() { case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_BF16: - this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); + this->out->getTensor() = this->in->getTensor() + .sum(this->dims) + .reshape(this->out->getTensor().dimensions()) + .unaryExpr([](float f) { return fpTrunc<Dtype>(f); }); break; case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_INT32: @@ -225,7 +228,9 @@ struct SumRequiresReducer { template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceSumInt<Rank, Dtype>::eval() { - this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, SumRequiresReducer(this->parent_sgt)) + .reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -250,12 +255,17 @@ struct SumDoubleReducer template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceSumDouble<Rank, Dtype>::eval() { + typename ReduceNode<Rank, Dtype>::TIn in_val = this->in->getTensor(); + if (g_func_config.abs_mode) + { + // in abs_mode: take abs values of in value + in_val = in_val.abs(); + } switch (Dtype) { case TOSA_REF_TYPE_FP64: - this->out->getTensor() = this->in->getTensor() - .reduce(this->dims, SumDoubleReducer()) - .reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = + in_val.reduce(this->dims, SumDoubleReducer()).reshape(this->out->getTensor().dimensions()); break; default: ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); |