diff options
Diffstat (limited to 'reference_model/src/ops/activation_funcs.cc')
-rw-r--r-- | reference_model/src/ops/activation_funcs.cc | 56 |
1 files changed, 37 insertions, 19 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 24bd077..6681d6d 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpClamp<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -32,9 +32,9 @@ int OpClamp<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: { InEigenType min = (InEigenType)attribute->min_fp(); InEigenType max = (InEigenType)attribute->max_fp(); @@ -43,8 +43,17 @@ int OpClamp<Rank, Dtype>::register_fcn() this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); }; } break; - case DType_INT8: - case DType_INT16: + case TOSA_REF_TYPE_FP64: + { + InEigenType min = (InEigenType)attribute->min_fp(); + InEigenType max = (InEigenType)attribute->max_fp(); + ERROR_IF(max < min, "OpClamp: max smaller than min"); + + this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); }; + } + break; + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: { InEigenType min = (InEigenType)attribute->min_int(); InEigenType max = (InEigenType)attribute->max_int(); @@ -53,19 +62,19 @@ int OpClamp<Rank, Dtype>::register_fcn() } break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpClamp<Rank, Dtype>::~OpClamp() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSigmoid<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -74,21 +83,24 @@ int OpSigmoid<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.f / (1.f + (expf(-1.f * a)))); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.L / (1.L + (exp(-1.L * a)))); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTanh<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -97,13 +109,16 @@ int OpTanh<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return tanh(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -115,11 +130,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64); |