From a4d748b08accce06fab93e2d2b96e499b35ae89b Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 28 Mar 2023 22:06:56 +0000 Subject: [reference model] Add precise mode This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly Change-Id: I156055216ad61710096497a8fa1a653be2a602a3 --- reference_model/include/dtype.h | 132 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 reference_model/include/dtype.h (limited to 'reference_model/include/dtype.h') diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h new file mode 100644 index 0000000..4976b54 --- /dev/null +++ b/reference_model/include/dtype.h @@ -0,0 +1,132 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_REFERENCE_DTYPE_H +#define TOSA_REFERENCE_DTYPE_H + +#include "model_common.h" +#include "tosa_generated.h" +#include + +using namespace tosa; + +namespace TosaReference +{ + +// Reference Model version of tosa.fbs enum DType +// Plus a FP64 data type for precise mode. +enum TOSA_REF_TYPE : uint32_t +{ + TOSA_REF_TYPE_UNKNOWN = 0, + TOSA_REF_TYPE_BOOL = 1, + TOSA_REF_TYPE_UINT8 = 2, + TOSA_REF_TYPE_INT4 = 3, + TOSA_REF_TYPE_INT8 = 4, + TOSA_REF_TYPE_INT16 = 5, + TOSA_REF_TYPE_INT32 = 6, + TOSA_REF_TYPE_INT48 = 7, + TOSA_REF_TYPE_FP32 = 8, + TOSA_REF_TYPE_UINT16 = 9, + TOSA_REF_TYPE_FP16 = 10, + TOSA_REF_TYPE_BF16 = 11, + TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above +}; + +inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) +{ + switch (e) + { + case TOSA_REF_TYPE_UNKNOWN: + return EnumNameDType(DType_UNKNOWN); + case TOSA_REF_TYPE_BOOL: + return EnumNameDType(DType_BOOL); + case TOSA_REF_TYPE_UINT8: + return EnumNameDType(DType_UINT8); + case TOSA_REF_TYPE_INT4: + return EnumNameDType(DType_INT4); + case TOSA_REF_TYPE_INT8: + return EnumNameDType(DType_INT8); + case TOSA_REF_TYPE_INT16: + return EnumNameDType(DType_INT16); + case TOSA_REF_TYPE_INT32: + return EnumNameDType(DType_INT32); + case TOSA_REF_TYPE_INT48: + return EnumNameDType(DType_INT48); + case TOSA_REF_TYPE_FP32: + return EnumNameDType(DType_FP32); + case TOSA_REF_TYPE_UINT16: + return EnumNameDType(DType_UINT16); + case TOSA_REF_TYPE_FP16: + return EnumNameDType(DType_FP16); + case TOSA_REF_TYPE_BF16: + return EnumNameDType(DType_BF16); + case TOSA_REF_TYPE_FP64: + return "FP64"; + default: + assert(false); + } +} + +// return corresponding TOSA_REF_TYPE for DType +inline TOSA_REF_TYPE ConvertDType(const DType dtype) +{ + assert(DType_MAX == DType_BF16); // must update whenever DType_MAX changes + + if (g_func_config.precise_mode) + { + // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64 + switch (dtype) + { + case DType_FP16: + case DType_FP32: + case DType_BF16: + return TOSA_REF_TYPE_FP64; + default: + break; + } + } + + switch (dtype) + { + case DType_BOOL: + return TOSA_REF_TYPE_BOOL; + case DType_UINT8: + return TOSA_REF_TYPE_UINT8; + case DType_INT4: + return TOSA_REF_TYPE_INT4; + case DType_INT8: + return TOSA_REF_TYPE_INT8; + case DType_INT16: + return TOSA_REF_TYPE_INT16; + case DType_INT32: + return TOSA_REF_TYPE_INT32; + case DType_INT48: + return TOSA_REF_TYPE_INT48; + case DType_FP32: + return TOSA_REF_TYPE_FP32; + case DType_UINT16: + return TOSA_REF_TYPE_UINT16; + case DType_FP16: + return TOSA_REF_TYPE_FP16; + case DType_BF16: + return TOSA_REF_TYPE_BF16; + default: + break; + } + return TOSA_REF_TYPE_UNKNOWN; +} + +}; // namespace TosaReference + +#endif -- cgit v1.2.1