aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJerry Ge <jerry.ge@arm.com>2024-02-20 11:21:19 -0800
committerEric Kunze <eric.kunze@arm.com>2024-03-05 18:47:09 +0000
commit51bd4f5166c50a89017307b55dee0f5bda096b7b (patch)
tree84aa5e9bd1dc02856ae10a18f0a923e3c8efbf55
parent1408795800719139e26bafcece88bfc07582576d (diff)
downloadreference_model-51bd4f5166c50a89017307b55dee0f5bda096b7b.tar.gz
Add Tosa Sin/Cos operators
- Add Tosa Sin/Cos operators to reference_model - Add conformances tests Signed-off-by: Jerry Ge <jerry.ge@arm.com> Change-Id: I3f597ddf5dac2c64d6dd6aa15781b40b8468eaa6
-rw-r--r--reference_model/src/generate/generate_utils.cc2
-rw-r--r--reference_model/src/ops/ewise_unary.cc68
-rw-r--r--reference_model/src/ops/ewise_unary.h4
-rw-r--r--reference_model/src/ops/op_factory.cc12
-rw-r--r--reference_model/src/verify/verify_abs_error.cc3
-rw-r--r--reference_model/src/verify/verify_utils.cc7
-rw-r--r--reference_model/src/verify/verify_utils.h1
-rw-r--r--scripts/schemavalidation/compliance-config.schema.json4
-rw-r--r--verif/conformance/test_select.py12
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json124
-rw-r--r--verif/generator/tosa_test_gen.py48
11 files changed, 280 insertions, 5 deletions
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 271b7f5..d0d0194 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -51,6 +51,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_CONCAT, "CONCAT" },
{ Op::Op_CONST, "CONST" },
{ Op::Op_CONV2D, "CONV2D" },
+ { Op::Op_COS, "COS" },
{ Op::Op_DEPTHWISE_CONV2D, "DEPTHWISE_CONV2D" },
{ Op::Op_CONV3D, "CONV3D" },
{ Op::Op_EQUAL, "EQUAL" },
@@ -85,6 +86,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_SCATTER, "SCATTER" },
{ Op::Op_SELECT, "SELECT" },
{ Op::Op_SIGMOID, "SIGMOID" },
+ { Op::Op_SIN, "SIN" },
{ Op::Op_SLICE, "SLICE" },
{ Op::Op_SUB, "SUB" },
{ Op::Op_TANH, "TANH" },
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index d92cde1..dd9ea5a 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -162,6 +162,34 @@ int OpClz<Rank, Dtype>::register_fcn()
}
template <int Rank, TOSA_REF_TYPE Dtype>
+int OpCos<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(cos(a)); };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ if (g_func_config.abs_mode)
+ {
+ // ABS_ERROR bounds return 1.0
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
+ }
+ else
+ {
+ this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); };
+ };
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return 0;
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
@@ -366,6 +394,34 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
return 0;
}
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpSin<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(sin(a)); };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ if (g_func_config.abs_mode)
+ {
+ // ABS_ERROR bounds return 1.0
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
+ }
+ else
+ {
+ this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); };
+ };
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return 0;
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16);
@@ -393,6 +449,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);
+
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
@@ -423,6 +484,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);
+
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 21ee276..a447388 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -65,12 +65,14 @@ DEF_TEMPLATE_UNARY_OP(Abs, ABS)
DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
DEF_TEMPLATE_UNARY_OP(Clz, CLZ)
+DEF_TEMPLATE_UNARY_OP(Cos, COS)
DEF_TEMPLATE_UNARY_OP(Exp, EXP)
DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
DEF_TEMPLATE_UNARY_OP(Log, LOG)
DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
+DEF_TEMPLATE_UNARY_OP(Sin, SIN)
#undef DEF_TEMPLATE_UNARY_OP
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 6d66c07..1891ff4 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -299,6 +299,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
break;
+ case Op_COS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);
+ break;
case Op_EXP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
@@ -341,6 +347,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
break;
+ case Op_SIN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);
+ break;
// ewise_ternary
case Op_SELECT:
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index a7b7bc2..125045e 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -20,7 +20,6 @@
#include "half.hpp"
#include "verifiers.h"
-
namespace TosaReference
{
@@ -36,7 +35,7 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg
{
valBound = std::max(cfg->lowerBound, valBound);
}
- return exp2(-AccPrecision<OutType>::normal_frac) * valBound;
+ return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
}
} // namespace
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index d598e2c..50a98e5 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -80,6 +80,10 @@ void from_json(const nlohmann::json& j, AbsErrorVerifyInfo& absErrorInfo)
{
j.at("lower_bound").get_to(absErrorInfo.lowerBound);
}
+ if (j.contains("normal_divisor"))
+ {
+ j.at("normal_divisor").get_to(absErrorInfo.normalDivisor);
+ }
}
void from_json(const nlohmann::json& j, RelativeVerifyInfo& rInfo)
@@ -108,7 +112,8 @@ void from_json(const nlohmann::json& j, VerifyConfig& cfg)
{
j.at("reduce_product_info").get_to(cfg.reduceProductInfo);
}
- cfg.absErrorInfo.lowerBound = 0;
+ cfg.absErrorInfo.lowerBound = 0;
+ cfg.absErrorInfo.normalDivisor = 1;
if (j.contains("abs_error_info"))
{
j.at("abs_error_info").get_to(cfg.absErrorInfo);
diff --git a/reference_model/src/verify/verify_utils.h b/reference_model/src/verify/verify_utils.h
index f53838a..9144317 100644
--- a/reference_model/src/verify/verify_utils.h
+++ b/reference_model/src/verify/verify_utils.h
@@ -82,6 +82,7 @@ struct AbsErrorVerifyInfo
AbsErrorVerifyInfo() = default;
double lowerBound;
+ double normalDivisor;
};
/// \brief relative verification meta-data
diff --git a/scripts/schemavalidation/compliance-config.schema.json b/scripts/schemavalidation/compliance-config.schema.json
index eb3ccde..f4a310c 100644
--- a/scripts/schemavalidation/compliance-config.schema.json
+++ b/scripts/schemavalidation/compliance-config.schema.json
@@ -70,6 +70,10 @@
"lower_bound": {
"description": "lower bound multiplier for error bounds",
"type": "number"
+ },
+ "normal_divisor": {
+ "description": "normal_divisor for error bounds",
+ "type": "number"
}
},
"additionalProperties": false
diff --git a/verif/conformance/test_select.py b/verif/conformance/test_select.py
index 55eef58..58d3f9f 100644
--- a/verif/conformance/test_select.py
+++ b/verif/conformance/test_select.py
@@ -848,6 +848,18 @@ class RsqrtOperator(Operator):
name = "rsqrt"
+class CosOperator(Operator):
+ """Test selector for the COS operator."""
+
+ name = "cos"
+
+
+class SinOperator(Operator):
+ """Test selector for the SIN operator."""
+
+ name = "sin"
+
+
class ScatterOperator(Operator):
"""Test selector for the SCATTER operator."""
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 18e078a..9d68dbf 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -2269,6 +2269,130 @@
}
}
},
+ "cos": {
+ "group": "ew_unary",
+ "profile": [
+ "tosa-mi"
+ ],
+ "support_for": [ "lazy_data_gen" ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "15,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3"
+ ],
+ [
+ "--target-dtype",
+ "fp16",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "1,15",
+ "--target-rank",
+ "4",
+ "--target-rank",
+ "5"
+ ],
+ [
+ "--target-dtype",
+ "fp32",
+ "--fp-values-range",
+ "-max,max",
+ "--target-shape",
+ "2,1,65537,1",
+ "--target-shape",
+ "3,1,65539,2,1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape",
+ "type"
+ ]
+ }
+ }
+ },
+ "sin": {
+ "group": "ew_unary",
+ "profile": [
+ "tosa-mi"
+ ],
+ "support_for": [ "lazy_data_gen" ],
+ "generation": {
+ "standard": {
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp32",
+ "--target-dtype",
+ "fp16",
+ "--target-dtype",
+ "bf16",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "15,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3"
+ ],
+ [
+ "--target-dtype",
+ "fp32",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "1,15",
+ "--target-rank",
+ "4",
+ "--target-rank",
+ "5"
+ ],
+ [
+ "--target-dtype",
+ "fp16",
+ "--fp-values-range",
+ "-max,max",
+ "--target-shape",
+ "3,1,65534,2",
+ "--target-shape",
+ "65533,1,3,2,1"
+ ]
+ ]
+ }
+ },
+ "selection": {
+ "default": {
+ "params": {},
+ "permutes": [
+ "shape",
+ "type"
+ ]
+ }
+ }
+ },
"rsqrt": {
"group": "ew_unary",
"profile": [
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b472087..978e735 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -392,6 +392,12 @@ class TosaTestGen:
compliance_tens["abs_error_info"] = {
"lower_bound": op["compliance"]["abs_error_lower_bound"]
}
+ elif op["op"] in (Op.SIN, Op.COS):
+ mode = gtu.ComplianceMode.ABS_ERROR
+ if "compliance" in op and "abs_error_normal_divisor" in op["compliance"]:
+ compliance_tens["abs_error_info"] = {
+ "normal_divisor": op["compliance"]["abs_error_normal_divisor"]
+ }
else:
mode = gtu.ComplianceMode.EXACT
compliance_tens["mode"] = gtu.ComplianceMode(mode).name
@@ -4036,6 +4042,27 @@ class TosaTestGen:
TosaErrorValidator.evWrongOutputList,
),
},
+ "cos": {
+ "op": Op.COS,
+ "operands": (1, 0),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaArgGen.agNone,
+ ),
+ "types": TYPE_FP,
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
+ "compliance": {"abs_error_normal_divisor": 2},
+ },
"exp": {
"op": Op.EXP,
"operands": (1, 0),
@@ -4180,6 +4207,27 @@ class TosaTestGen:
},
"compliance": {"ulp": 2},
},
+ "sin": {
+ "op": Op.SIN,
+ "operands": (1, 0),
+ "build_fcn": (
+ build_unary,
+ TosaTensorGen.tgBasic,
+ TosaTensorValuesGen.tvgLazyGenDefault,
+ TosaArgGen.agNone,
+ ),
+ "types": TYPE_FP,
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ ),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
+ "compliance": {"abs_error_normal_divisor": 2},
+ },
# Elementwise Ternary operators
"select": {
"op": Op.SELECT,