From a4d748b08accce06fab93e2d2b96e499b35ae89b Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Tue, 28 Mar 2023 22:06:56 +0000
Subject: [reference model] Add precise mode

This adds the --precise_mode=1 option to tosa_reference_model, which
will cause the reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.

Also adds an optional -p argument to the test runners
tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py
to run tests in precise mode.

Signed-off-by: Tai Ly
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
---
 reference_model/src/ops/type_conversion.cc | 116 +++++++++++++++++++----------
 1 file changed, 75 insertions(+), 41 deletions(-)

(limited to 'reference_model/src/ops/type_conversion.cc')

diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 9034add..68ffb1f 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 //    Licensed under the Apache License, Version 2.0 (the "License");
 //    you may not use this file except in compliance with the License.
 //    You may obtain a copy of the License at
@@ -24,7 +24,7 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpRescale<InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
                                         TosaAttributeBase* attribute_,
                                         uint64_t id_)
@@ -35,14 +35,14 @@ OpRescale<InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Rescale);
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpRescale<InDtype, OutDtype>::~OpRescale()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpRescale<InDtype, OutDtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -69,31 +69,33 @@ int OpRescale<InDtype, OutDtype>::checkTensorAttributes()
 
     ASSERT_MEM(in && out);
 
-    if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
+    if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
+        (attribute->input_zp() != 0))
     {
-        printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
+        printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
         return 1;
     }
 
-    if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
+    if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
+        (attribute->output_zp() != 0))
     {
-        printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+        printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
         return 1;
     }
 
-    if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+    if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
     {
-        printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+        printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
         return 1;
     }
 
-    if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+    if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
     {
-        printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
+        printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
         return 1;
     }
 
-    if (attribute->scale32() && (InDtype == DType_INT48))
+    if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
     {
         printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
         return 1;
@@ -108,7 +110,7 @@ int OpRescale<InDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpRescale<InDtype, OutDtype>::eval()
 {
     int32_t input_zp = attribute->input_zp();
@@ -237,7 +239,7 @@ int OpRescale<InDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpCast<InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
                                   TosaAttributeBase* attribute_,
                                   uint64_t id_)
@@ -247,11 +249,11 @@ OpCast<InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
     setRequiredRank(0, 6);
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpCast<InDtype, OutDtype>::~OpCast()
 {}
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpCast<InDtype, OutDtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -281,7 +283,7 @@ int OpCast<InDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpCast<InDtype, OutDtype>::eval()
 {
     this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
@@ -289,7 +291,7 @@ int OpCast<InDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 CastHelper<InDtype, OutDtype>::CastHelper()
 {
     fcn = [](InEigenType in) -> OutEigenType {
@@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper()
     };
 }
 
-template <DType InDtype>
-CastHelper<InDtype, DType_BOOL>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
 {
     fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
 }
 
-template <DType OutDtype>
-CastHelper<DType_BOOL, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
 {
     fcn = [](bool in) -> OutEigenType {
         OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
@@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
     };
 }
 
-template <DType InDtype>
-CastHelper<InDtype, DType_FP16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
 {
     // Integer data converted to fp16 (stored as fp32)
     fcn = [](InEigenType in) -> float {
@@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper()
     };
 }
 
-CastHelper<DType_FP32, DType_FP16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
 {
     // fp32 data converted to fp16 (stored as fp32)
     fcn = [](float in) -> float {
-        float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
+        float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
         return out;
     };
 }
 
-template <DType InDtype>
-CastHelper<InDtype, DType_BF16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
 {
     // Integer data converted to bf16 (stored as fp32)
     fcn = [](InEigenType in) -> float {
@@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper()
     };
 }
 
-CastHelper<DType_FP32, DType_BF16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
 {
     // fp32 data converted to bf16 (stored as fp32)
     fcn = [](float in) -> float {
-        return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision
+        return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
     };
 }
 
-template <DType OutDtype>
-CastHelper<DType_FP16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
 {
     // fp16 data (stored as fp32) converted to integer
     fcn = [](float in) -> OutEigenType {
@@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper()
     };
 }
 
-CastHelper<DType_FP16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
 {
     // No-op since fp16 values treated internally as their fp32 representation
     fcn = [](float in) -> OutEigenType {
@@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper()
     };
 }
 
-template <DType OutDtype>
-CastHelper<DType_BF16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
 {
     // bf16 data (stored as fp32) converted to integer
     fcn = [](float in) -> OutEigenType {
@@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper()
     };
 }
 
-CastHelper<DType_BF16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
 {
     // No-op since bf16 values treated as truncated fp32 internally
     fcn = [](InEigenType in) -> OutEigenType {
@@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper()
     };
 }
 
-template <DType InDtype>
-CastHelper<InDtype, DType_FP32>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
 {
     // Integer data converted to fp32
     fcn = [](InEigenType in) -> float {
@@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper()
     };
 }
 
-template <DType OutDtype>
-CastHelper<DType_FP32, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
 {
     // fp32 data converted to integer
     fcn = [](float in) -> OutEigenType {
@@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper()
     };
 }
 
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
+{
+    switch (OutDtype)
+    {
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
+            // fp64 data converted to integer
+            fcn = [](InEigenType in) -> OutEigenType {
+                OutEigenType out = std::rint(in);
+                out              = std::max<OutEigenType>(out, OutMin);
+                out              = std::min<OutEigenType>(out, OutMax);
+                return out;
+            };
+            break;
+        case TOSA_REF_TYPE_FP64:
+            // no op
+            fcn = [](InEigenType in) -> OutEigenType { return in; };
+            break;
+        default:
+            ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
+    }
+}
+
 // template explicit instantiation
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
@@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
--
cgit v1.2.1
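Editor's note (not part of the patch): as the commit message states, precise mode is
enabled by passing --precise_mode=1 to tosa_reference_model, or -p to the
tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py test runners.

The new CastHelper<TOSA_REF_TYPE_FP64, OutDtype> specialization above picks its
conversion function with a switch on the OutDtype template parameter: the
FP64->INT8/INT16/INT32 path rounds with std::rint (round-to-nearest, ties-to-even
under the default rounding mode) and then saturates to [OutMin, OutMax], while
FP64->FP64 is a pass-through. The standalone C++ sketch below illustrates only that
round-and-saturate behaviour; the hypothetical cast_fp64_to_int helper is not
reference-model code, and it assumes the output bounds match std::numeric_limits
(the helper's actual OutMin/OutMax constants are defined elsewhere in the model).

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Hypothetical sketch of the fp64-to-integer branch of the new CastHelper:
// round to nearest (ties to even), then clamp to the output type's range.
template <typename OutT>
OutT cast_fp64_to_int(double in)
{
    double out = std::rint(in); // ties-to-even under the default FE_TONEAREST mode
    out        = std::max(out, static_cast<double>(std::numeric_limits<OutT>::min()));
    out        = std::min(out, static_cast<double>(std::numeric_limits<OutT>::max()));
    return static_cast<OutT>(out);
}

int main()
{
    std::printf("%d\n", cast_fp64_to_int<int8_t>(2.5));   // 2      (tie rounds to even)
    std::printf("%d\n", cast_fp64_to_int<int8_t>(3.5));   // 4      (tie rounds to even)
    std::printf("%d\n", cast_fp64_to_int<int8_t>(1e9));   // 127    (saturated)
    std::printf("%d\n", cast_fp64_to_int<int16_t>(-1e9)); // -32768 (saturated)
    return 0;
}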