diff options
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 2 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_unary.cc | 68 | ||||
-rw-r--r-- | reference_model/src/ops/ewise_unary.h | 4 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 12 | ||||
-rw-r--r-- | reference_model/src/verify/verify_abs_error.cc | 3 | ||||
-rw-r--r-- | reference_model/src/verify/verify_utils.cc | 7 | ||||
-rw-r--r-- | reference_model/src/verify/verify_utils.h | 1 |
7 files changed, 92 insertions, 5 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index 271b7f5..d0d0194 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -51,6 +51,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_CONCAT, "CONCAT" }, { Op::Op_CONST, "CONST" }, { Op::Op_CONV2D, "CONV2D" }, + { Op::Op_COS, "COS" }, { Op::Op_DEPTHWISE_CONV2D, "DEPTHWISE_CONV2D" }, { Op::Op_CONV3D, "CONV3D" }, { Op::Op_EQUAL, "EQUAL" }, @@ -85,6 +86,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_SCATTER, "SCATTER" }, { Op::Op_SELECT, "SELECT" }, { Op::Op_SIGMOID, "SIGMOID" }, + { Op::Op_SIN, "SIN" }, { Op::Op_SLICE, "SLICE" }, { Op::Op_SUB, "SUB" }, { Op::Op_TANH, "TANH" }, diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index d92cde1..dd9ea5a 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -162,6 +162,34 @@ int OpClz<Rank, Dtype>::register_fcn() } template <int Rank, TOSA_REF_TYPE Dtype> +int OpCos<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(cos(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + +template <int Rank, TOSA_REF_TYPE Dtype> int OpExp<Rank, Dtype>::register_fcn() { switch (Dtype) @@ -366,6 +394,34 @@ int OpRsqrt<Rank, Dtype>::register_fcn() return 0; } +template <int Rank, TOSA_REF_TYPE Dtype> +int OpSin<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(sin(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16); @@ -393,6 +449,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); @@ -423,6 +484,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 21ee276..a447388 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -65,12 +65,14 @@ DEF_TEMPLATE_UNARY_OP(Abs, ABS) DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) DEF_TEMPLATE_UNARY_OP(Clz, CLZ) +DEF_TEMPLATE_UNARY_OP(Cos, COS) DEF_TEMPLATE_UNARY_OP(Exp, EXP) DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) DEF_TEMPLATE_UNARY_OP(Log, LOG) DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) +DEF_TEMPLATE_UNARY_OP(Sin, SIN) #undef DEF_TEMPLATE_UNARY_OP diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 6d66c07..1891ff4 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -299,6 +299,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; + case Op_COS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64); + break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); @@ -341,6 +347,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); break; + case Op_SIN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64); + break; // ewise_ternary case Op_SELECT: diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index a7b7bc2..125045e 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -20,7 +20,6 @@ #include "half.hpp" #include "verifiers.h" - namespace TosaReference { @@ -36,7 +35,7 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg { valBound = std::max(cfg->lowerBound, valBound); } - return exp2(-AccPrecision<OutType>::normal_frac) * valBound; + return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound; } } // namespace diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index d598e2c..50a98e5 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -80,6 +80,10 @@ void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo) { j.at("lower_bound").get_to(absErrorInfo.lowerBound); } + if (j.contains("normal_divisor")) + { + j.at("normal_divisor").get_to(absErrorInfo.normalDivisor); + } } void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo) @@ -108,7 +112,8 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg) { j.at("reduce_product_info").get_to(cfg.reduceProductInfo); } - cfg.absErrorInfo.lowerBound = 0; + cfg.absErrorInfo.lowerBound = 0; + cfg.absErrorInfo.normalDivisor = 1; if (j.contains("abs_error_info")) { j.at("abs_error_info").get_to(cfg.absErrorInfo); diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h index f53838a..9144317 100644 --- a/reference_model/src/verify/verify_utils.h +++ b/reference_model/src/verify/verify_utils.h @@ -82,6 +82,7 @@ struct AbsErrorVerifyInfo AbsErrorVerifyInfo() = default; double lowerBound; + double normalDivisor; }; /// \brief relative verification meta-data |