aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r--reference_model/src/ops/ewise_unary.cc68
-rw-r--r--reference_model/src/ops/ewise_unary.h4
-rw-r--r--reference_model/src/ops/op_factory.cc12
3 files changed, 82 insertions, 2 deletions
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index d92cde1..dd9ea5a 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -162,6 +162,34 @@ int OpClz<Rank, Dtype>::register_fcn()
}
template <int Rank, TOSA_REF_TYPE Dtype>
+int OpCos<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(cos(a)); };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ if (g_func_config.abs_mode)
+ {
+ // ABS_ERROR bounds return 1.0
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
+ }
+ else
+ {
+ this->fcn = [](InEigenType a) -> OutEigenType { return cos(a); };
+ };
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return 0;
+}
+
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
@@ -366,6 +394,34 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
return 0;
}
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpSin<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(sin(a)); };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ if (g_func_config.abs_mode)
+ {
+ // ABS_ERROR bounds return 1.0
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0; };
+ }
+ else
+ {
+ this->fcn = [](InEigenType a) -> OutEigenType { return sin(a); };
+ };
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return 0;
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP16);
@@ -393,6 +449,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);
+
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
@@ -423,6 +484,11 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);
+
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 21ee276..a447388 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -65,12 +65,14 @@ DEF_TEMPLATE_UNARY_OP(Abs, ABS)
DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
DEF_TEMPLATE_UNARY_OP(Clz, CLZ)
+DEF_TEMPLATE_UNARY_OP(Cos, COS)
DEF_TEMPLATE_UNARY_OP(Exp, EXP)
DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
DEF_TEMPLATE_UNARY_OP(Log, LOG)
DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
+DEF_TEMPLATE_UNARY_OP(Sin, SIN)
#undef DEF_TEMPLATE_UNARY_OP
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 6d66c07..1891ff4 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -299,6 +299,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
break;
+ case Op_COS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCos, FP64);
+ break;
case Op_EXP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
@@ -341,6 +347,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
break;
+ case Op_SIN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSin, FP64);
+ break;
// ewise_ternary
case Op_SELECT: