diff options
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 164 |
1 files changed, 99 insertions, 65 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8dc37e2..514cb84 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -35,11 +35,11 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 }; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> UnaryNode<Rank, Dtype>::~UnaryNode() {} -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int UnaryNode<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -69,7 +69,7 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int UnaryNode<Rank, Dtype>::eval() { this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); @@ -77,71 +77,75 @@ int UnaryNode<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpAbs<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP32: // No fpTrunc for FP32 as it is a no-op - case DType_INT32: + case TOSA_REF_TYPE_FP32: // No fpTrunc for FP32 as it is a no-op + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? a : (-a)); }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpBitwiseNot<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpCeil<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpClz<Rank, Dtype>::register_fcn() { int32_t num_bits; switch (Dtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: num_bits = 32; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -163,73 +167,82 @@ int OpClz<Rank, Dtype>::register_fcn() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpExp<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpFloor<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLog<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return log(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalNot<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -240,31 +253,37 @@ OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, register_fcn(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpNegate<Rank, Dtype>::~OpNegate() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpNegate<Rank, Dtype>::register_fcn() { - ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); - ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return fpTrunc<Dtype>(result); }; break; - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { + OutEigenType result = -(a); + return result; + }; + break; + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0L - a; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); @@ -272,7 +291,7 @@ int OpNegate<Rank, Dtype>::register_fcn() REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); int64_t max_clip_in_64, min_clip_in_64; - if (Dtype == DType_INT16) + if (Dtype == TOSA_REF_TYPE_INT16) { max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max()); min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min()); @@ -285,7 +304,7 @@ int OpNegate<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64))); }; break; - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0 - (a - attribute->input1_zp()); int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); @@ -297,41 +316,47 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReciprocal<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpRsqrt<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / sqrt(a)); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -345,11 +370,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); @@ -358,20 +385,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -381,11 +412,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); |