diff options
Diffstat (limited to 'reference_model/src/ops/reduction.cc')
-rw-r--r-- | reference_model/src/ops/reduction.cc | 111 |
1 file changed, 94 insertions, 17 deletions
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index cd9d55f..bf8ba57 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -30,14 +30,14 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> ReduceNode<Rank, Dtype>::~ReduceNode() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int ReduceNode<Rank, Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -100,7 +100,7 @@ struct AnyReducer { bool finalize(const bool accum) const { return accum; } }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceAll<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); @@ -108,7 +108,7 @@ int OpReduceAll<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceAny<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); @@ -116,7 +116,7 @@ int OpReduceAny<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template 
<int Rank, TOSA_REF_TYPE Dtype> int OpReduceMax<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -124,7 +124,7 @@ int OpReduceMax<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceMin<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -132,35 +132,74 @@ int OpReduceMin<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceProduct<Rank, Dtype>::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); break; - default: + case TOSA_REF_TYPE_FP32: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + +struct ProductDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum *= val; + } + double initialize() const + { + return 1.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template <int Rank, TOSA_REF_TYPE Dtype> +int OpReduceProductDouble<Rank, Dtype>::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, ProductDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, 
TOSA_REF_TYPE Dtype> int OpReduceSum<Rank, Dtype>::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); break; - default: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); @@ -183,7 +222,7 @@ struct SumRequiresReducer { SubgraphTraverser* parent_sgt; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceSumInt<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions()); @@ -191,6 +230,40 @@ int OpReduceSumInt<Rank, Dtype>::eval() return GraphNode::eval(); } +struct SumDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum += val; + } + double initialize() const + { + return 0.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template <int Rank, TOSA_REF_TYPE Dtype> +int OpReduceSumDouble<Rank, Dtype>::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, SumDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); @@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); @@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); |