diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/type_conversion.cc | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model,
which will cause reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/type_conversion.cc')
-rw-r--r-- | reference_model/src/ops/type_conversion.cc | 116 |
1 files changed, 75 insertions, 41 deletions
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 9034add..68ffb1f 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Rescale); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::~OpRescale() { if (attribute) delete attribute; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() { // Check Tosa Level @@ -69,31 +69,33 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0)) + if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) && + (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) + if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) && + (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) { - printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) { - printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if (attribute->scale32() && (InDtype == DType_INT48)) + if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48)) { printNodeValidationError("OpRescale: Scale set to true but input type is INT48"); return 1; @@ -108,7 +110,7 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpRescale<Rank, InDtype, OutDtype>::eval() { int32_t input_zp = attribute->input_zp(); @@ -237,7 +239,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -247,11 +249,11 @@ OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpCast<Rank, InDtype, OutDtype>::~OpCast() {} -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes() { // Check Tosa Level @@ -281,7 +283,7 @@ int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpCast<Rank, InDtype, OutDtype>::eval() { this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn()); @@ -289,7 +291,7 @@ int OpCast<Rank, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> CastHelper<InDtype, OutDtype>::CastHelper() { fcn = [](InEigenType in) -> OutEigenType { @@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_BOOL>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper() { fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; }; } -template <DType OutDtype> -CastHelper<DType_BOOL, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper() { fcn = [](bool in) -> OutEigenType { OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0; @@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_FP16>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper() { // Integer data converted to fp16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper() }; } -CastHelper<DType_FP32, DType_FP16>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper() { // fp32 data converted to fp16 (stored as fp32) fcn = [](float in) -> float { - float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision + float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision return out; }; } -template <DType InDtype> -CastHelper<InDtype, DType_BF16>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper() { // Integer data converted to bf16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper() }; } -CastHelper<DType_FP32, DType_BF16>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper() { // fp32 data converted to bf16 (stored as fp32) fcn = [](float in) -> float { - return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision + return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision }; } -template <DType OutDtype> -CastHelper<DType_FP16, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper() { // fp16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper() }; } -CastHelper<DType_FP16, DType_FP32>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper() { // No-op since fp16 values treated internally as their fp32 representation fcn = [](float in) -> OutEigenType { @@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper() }; } -template <DType OutDtype> -CastHelper<DType_BF16, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper() { // bf16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper() }; } -CastHelper<DType_BF16, DType_FP32>::CastHelper() +CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper() { // No-op since bf16 values treated as truncated fp32 internally fcn = [](InEigenType in) -> OutEigenType { @@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_FP32>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper() { // Integer data converted to fp32 fcn = [](InEigenType in) -> float { @@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper() }; } -template <DType OutDtype> -CastHelper<DType_FP32, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper() { // fp32 data converted to integer fcn = [](float in) -> OutEigenType { @@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper() }; } +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper() +{ + switch (OutDtype) + { + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: + // fp64 data converted to integer + fcn = [](InEigenType in) -> OutEigenType { + OutEigenType out = std::rint(in); + out = std::max<OutEigenType>(out, OutMin); + out = std::min<OutEigenType>(out, OutMax); + return out; + }; + break; + case TOSA_REF_TYPE_FP64: + // no op + fcn = [](InEigenType in) -> OutEigenType { return in; }; + break; + default: + ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype)); + } +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); @@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); |