aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/type_conversion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r--reference_model/src/ops/type_conversion.cc116
1 files changed, 75 insertions, 41 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 9034add..68ffb1f 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Rescale);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -69,31 +69,33 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
- if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
+ if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->input_zp() != 0))
{
- printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
+ if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->output_zp() != 0))
{
- printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if (attribute->scale32() && (InDtype == DType_INT48))
+ if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
{
printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
return 1;
@@ -108,7 +110,7 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::eval()
{
int32_t input_zp = attribute->input_zp();
@@ -237,7 +239,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -247,11 +249,11 @@ OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::~OpCast()
{}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -281,7 +283,7 @@ int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::eval()
{
this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
@@ -289,7 +291,7 @@ int OpCast<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
CastHelper<InDtype, OutDtype>::CastHelper()
{
fcn = [](InEigenType in) -> OutEigenType {
@@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BOOL>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
{
fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
}
-template <DType OutDtype>
-CastHelper<DType_BOOL, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
{
fcn = [](bool in) -> OutEigenType {
OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
@@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
{
// Integer data converted to fp16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_FP16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
{
// fp32 data converted to fp16 (stored as fp32)
fcn = [](float in) -> float {
- float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
+ float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
return out;
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BF16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
{
// Integer data converted to bf16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_BF16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
{
// fp32 data converted to bf16 (stored as fp32)
fcn = [](float in) -> float {
- return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision
+ return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
};
}
-template <DType OutDtype>
-CastHelper<DType_FP16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
{
// fp16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_FP16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since fp16 values treated internally as their fp32 representation
fcn = [](float in) -> OutEigenType {
@@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_BF16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
{
// bf16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_BF16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since bf16 values treated as truncated fp32 internally
fcn = [](InEigenType in) -> OutEigenType {
@@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP32>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
{
// Integer data converted to fp32
fcn = [](InEigenType in) -> float {
@@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_FP32, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
{
// fp32 data converted to integer
fcn = [](float in) -> OutEigenType {
@@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper()
};
}
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
+{
+ switch (OutDtype)
+ {
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
+ // fp64 data converted to integer
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = std::rint(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ // no op
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+ break;
+ default:
+ ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
+ }
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
@@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);