From 51bd4f5166c50a89017307b55dee0f5bda096b7b Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Tue, 20 Feb 2024 11:21:19 -0800 Subject: Add Tosa Sin/Cos operators - Add Tosa Sin/Cos operators to reference_model - Add conformances tests Signed-off-by: Jerry Ge Change-Id: I3f597ddf5dac2c64d6dd6aa15781b40b8468eaa6 --- reference_model/src/generate/generate_utils.cc | 2 + reference_model/src/ops/ewise_unary.cc | 68 ++++++++++- reference_model/src/ops/ewise_unary.h | 4 +- reference_model/src/ops/op_factory.cc | 12 ++ reference_model/src/verify/verify_abs_error.cc | 3 +- reference_model/src/verify/verify_utils.cc | 7 +- reference_model/src/verify/verify_utils.h | 1 + .../schemavalidation/compliance-config.schema.json | 4 + verif/conformance/test_select.py | 12 ++ verif/conformance/tosa_main_profile_ops_info.json | 124 +++++++++++++++++++++ verif/generator/tosa_test_gen.py | 48 ++++++++ 11 files changed, 280 insertions(+), 5 deletions(-) diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index 271b7f5..d0d0194 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -51,6 +51,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_CONCAT, "CONCAT" }, { Op::Op_CONST, "CONST" }, { Op::Op_CONV2D, "CONV2D" }, + { Op::Op_COS, "COS" }, { Op::Op_DEPTHWISE_CONV2D, "DEPTHWISE_CONV2D" }, { Op::Op_CONV3D, "CONV3D" }, { Op::Op_EQUAL, "EQUAL" }, @@ -85,6 +86,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_SCATTER, "SCATTER" }, { Op::Op_SELECT, "SELECT" }, { Op::Op_SIGMOID, "SIGMOID" }, + { Op::Op_SIN, "SIN" }, { Op::Op_SLICE, "SLICE" }, { Op::Op_SUB, "SUB" }, { Op::Op_TANH, "TANH" }, diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index d92cde1..dd9ea5a 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -161,6 +161,34 @@ int OpClz::register_fcn() return 0; } +template +int OpCos::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(cos(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + template int OpExp::register_fcn() { @@ -366,6 +394,34 @@ int OpRsqrt::register_fcn() return 0; } +template +int OpSin::register_fcn() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(sin(a)); }; + break; + case TOSA_REF_TYPE_FP64: + if (g_func_config.abs_mode) + { + // ABS_ERROR bounds return 1.0 + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; }; + } + else + { + this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); }; + }; + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return 0; +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16); @@ -393,6 +449,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); @@ -423,6 +484,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64); + DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 21ee276..a447388 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -65,12 +65,14 @@ DEF_TEMPLATE_UNARY_OP(Abs, ABS) DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) DEF_TEMPLATE_UNARY_OP(Clz, CLZ) +DEF_TEMPLATE_UNARY_OP(Cos, COS) DEF_TEMPLATE_UNARY_OP(Exp, EXP) DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) DEF_TEMPLATE_UNARY_OP(Log, LOG) DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) +DEF_TEMPLATE_UNARY_OP(Sin, SIN) #undef DEF_TEMPLATE_UNARY_OP diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 6d66c07..1891ff4 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -299,6 +299,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; + case Op_COS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64); + break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); @@ -341,6 +347,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); break; + case Op_SIN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64); + break; // ewise_ternary case Op_SELECT: diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index a7b7bc2..125045e 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -20,7 +20,6 @@ #include "half.hpp" #include "verifiers.h" - namespace TosaReference { @@ -36,7 +35,7 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg { valBound = std::max(cfg->lowerBound, valBound); } - return exp2(-AccPrecision::normal_frac) * valBound; + return exp2(-AccPrecision::normal_frac / cfg->normalDivisor) * valBound; } } // namespace diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index d598e2c..50a98e5 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -80,6 +80,10 @@ void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo) { j.at("lower_bound").get_to(absErrorInfo.lowerBound); } + if (j.contains("normal_divisor")) + { + j.at("normal_divisor").get_to(absErrorInfo.normalDivisor); + } } void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo) @@ -108,7 +112,8 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg) { j.at("reduce_product_info").get_to(cfg.reduceProductInfo); } - cfg.absErrorInfo.lowerBound = 0; + cfg.absErrorInfo.lowerBound = 0; + cfg.absErrorInfo.normalDivisor = 1; if (j.contains("abs_error_info")) { j.at("abs_error_info").get_to(cfg.absErrorInfo); diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h index f53838a..9144317 100644 --- a/reference_model/src/verify/verify_utils.h +++ b/reference_model/src/verify/verify_utils.h @@ -82,6 +82,7 @@ struct AbsErrorVerifyInfo AbsErrorVerifyInfo() = default; double lowerBound; + double normalDivisor; }; /// \brief relative verification meta-data diff --git a/scripts/schemavalidation/compliance-config.schema.json b/scripts/schemavalidation/compliance-config.schema.json index eb3ccde..f4a310c 100644 --- a/scripts/schemavalidation/compliance-config.schema.json +++ b/scripts/schemavalidation/compliance-config.schema.json @@ -70,6 +70,10 @@ "lower_bound": { "description": "lower bound multiplier for error bounds", "type": "number" + }, + "normal_divisor": { + "description": "normal_divisor for error bounds", + "type": "number" } }, "additionalProperties": false diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py index 55eef58..58d3f9f 100644 --- a/verif/conformance/test_select.py +++ b/verif/conformance/test_select.py @@ -848,6 +848,18 @@ class RsqrtOperator(Operator): name = "rsqrt" +class CosOperator(Operator): + """Test selector for the COS operator.""" + + name = "cos" + + +class SinOperator(Operator): + """Test selector for the SIN operator.""" + + name = "sin" + + class ScatterOperator(Operator): """Test selector for the SCATTER operator.""" diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 18e078a..9d68dbf 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -2269,6 +2269,130 @@ } } }, + "cos": { + "group": "ew_unary", + "profile": [ + "tosa-mi" + ], + "support_for": [ "lazy_data_gen" ], + "generation": { + "standard": { + "generator_args": [ + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "15,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3" + ], + [ + "--target-dtype", + "fp16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "1,15", + "--target-rank", + "4", + "--target-rank", + "5" + ], + [ + "--target-dtype", + "fp32", + "--fp-values-range", + "-max,max", + "--target-shape", + "2,1,65537,1", + "--target-shape", + "3,1,65539,2,1" + ] + ] + } + }, + "selection": { + "default": { + "params": {}, + "permutes": [ + "shape", + "type" + ] + } + } + }, + "sin": { + "group": "ew_unary", + "profile": [ + "tosa-mi" + ], + "support_for": [ "lazy_data_gen" ], + "generation": { + "standard": { + "generator_args": [ + [ + "--target-dtype", + "fp32", + "--target-dtype", + "fp16", + "--target-dtype", + "bf16", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "15,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3" + ], + [ + "--target-dtype", + "fp32", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "1,15", + "--target-rank", + "4", + "--target-rank", + "5" + ], + [ + "--target-dtype", + "fp16", + "--fp-values-range", + "-max,max", + "--target-shape", + "3,1,65534,2", + "--target-shape", + "65533,1,3,2,1" + ] + ] + } + }, + "selection": { + "default": { + "params": {}, + "permutes": [ + "shape", + "type" + ] + } + } + }, "rsqrt": { "group": "ew_unary", "profile": [ diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b472087..978e735 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -392,6 +392,12 @@ class TosaTestGen: compliance_tens["abs_error_info"] = { "lower_bound": op["compliance"]["abs_error_lower_bound"] } + elif op["op"] in (Op.SIN, Op.COS): + mode = gtu.ComplianceMode.ABS_ERROR + if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]: + compliance_tens["abs_error_info"] = { + "normal_divisor": op["compliance"]["abs_error_normal_divisor"] + } else: mode = gtu.ComplianceMode.EXACT compliance_tens["mode"] = gtu.ComplianceMode(mode).name @@ -4036,6 +4042,27 @@ class TosaTestGen: TosaErrorValidator.evWrongOutputList, ), }, + "cos": { + "op": Op.COS, + "operands": (1, 0), + "build_fcn": ( + build_unary, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, + ), + "types": TYPE_FP, + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"abs_error_normal_divisor": 2}, + }, "exp": { "op": Op.EXP, "operands": (1, 0), @@ -4180,6 +4207,27 @@ class TosaTestGen: }, "compliance": {"ulp": 2}, }, + "sin": { + "op": Op.SIN, + "operands": (1, 0), + "build_fcn": ( + build_unary, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, + ), + "types": TYPE_FP, + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"abs_error_normal_divisor": 2}, + }, # Elementwise Ternary operators "select": { "op": Op.SELECT, -- cgit v1.2.1