diff options
Diffstat (limited to 'reference_model/src/ops/ewise_unary.cc')
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 68 |
1 files changed, 67 insertions, 1 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index d92cde1..dd9ea5a 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -162,6 +162,34 @@ int OpClz<Rank, Dtype>::register_fcn() } template <int Rank, TOSA_REF_TYPE Dtype> +int OpCos<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(cos(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + +template <int Rank, TOSA_REF_TYPE Dtype> int OpExp<Rank, Dtype>::register_fcn() { switch (Dtype) @@ -366,6 +394,34 @@ int OpRsqrt<Rank, Dtype>::register_fcn() return 0; } +template <int Rank, TOSA_REF_TYPE Dtype> +int OpSin<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(sin(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16); @@ -393,6 +449,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); @@ -423,6 +484,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); |