diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/include | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model,
which will cause reference model to convert all floating point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/include')
-rw-r--r-- | reference_model/include/dtype.h | 132 | ||||
-rw-r--r-- | reference_model/include/func_config.h | 1 |
2 files changed, 133 insertions, 0 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h new file mode 100644 index 0000000..4976b54 --- /dev/null +++ b/reference_model/include/dtype.h @@ -0,0 +1,132 @@ +// Copyright (c) 2023, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_REFERENCE_DTYPE_H +#define TOSA_REFERENCE_DTYPE_H + +#include "model_common.h" +#include "tosa_generated.h" +#include <cstdint> + +using namespace tosa; + +namespace TosaReference +{ + +// Reference Model version of tosa.fbs enum DType +// Plus a FP64 data type for precise mode. +enum TOSA_REF_TYPE : uint32_t +{ + TOSA_REF_TYPE_UNKNOWN = 0, + TOSA_REF_TYPE_BOOL = 1, + TOSA_REF_TYPE_UINT8 = 2, + TOSA_REF_TYPE_INT4 = 3, + TOSA_REF_TYPE_INT8 = 4, + TOSA_REF_TYPE_INT16 = 5, + TOSA_REF_TYPE_INT32 = 6, + TOSA_REF_TYPE_INT48 = 7, + TOSA_REF_TYPE_FP32 = 8, + TOSA_REF_TYPE_UINT16 = 9, + TOSA_REF_TYPE_FP16 = 10, + TOSA_REF_TYPE_BF16 = 11, + TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above +}; + +inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) +{ + switch (e) + { + case TOSA_REF_TYPE_UNKNOWN: + return EnumNameDType(DType_UNKNOWN); + case TOSA_REF_TYPE_BOOL: + return EnumNameDType(DType_BOOL); + case TOSA_REF_TYPE_UINT8: + return EnumNameDType(DType_UINT8); + case TOSA_REF_TYPE_INT4: + return EnumNameDType(DType_INT4); + case TOSA_REF_TYPE_INT8: + return EnumNameDType(DType_INT8); + case TOSA_REF_TYPE_INT16: + return EnumNameDType(DType_INT16); + case TOSA_REF_TYPE_INT32: + return EnumNameDType(DType_INT32); + case TOSA_REF_TYPE_INT48: + return EnumNameDType(DType_INT48); + case TOSA_REF_TYPE_FP32: + return EnumNameDType(DType_FP32); + case TOSA_REF_TYPE_UINT16: + return EnumNameDType(DType_UINT16); + case TOSA_REF_TYPE_FP16: + return EnumNameDType(DType_FP16); + case TOSA_REF_TYPE_BF16: + return EnumNameDType(DType_BF16); + case TOSA_REF_TYPE_FP64: + return "FP64"; + default: + assert(false); + } +} + +// return corresponding TOSA_REF_TYPE for DType +inline TOSA_REF_TYPE ConvertDType(const DType dtype) +{ + assert(DType_MAX == DType_BF16); // must update whenever DType_MAX changes + + if (g_func_config.precise_mode) + { + // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64 + switch (dtype) + { + case DType_FP16: + case DType_FP32: + case DType_BF16: + return TOSA_REF_TYPE_FP64; + default: + break; + } + } + + switch (dtype) + { + case DType_BOOL: + return TOSA_REF_TYPE_BOOL; + case DType_UINT8: + return TOSA_REF_TYPE_UINT8; + case DType_INT4: + return TOSA_REF_TYPE_INT4; + case DType_INT8: + return TOSA_REF_TYPE_INT8; + case DType_INT16: + return TOSA_REF_TYPE_INT16; + case DType_INT32: + return TOSA_REF_TYPE_INT32; + case DType_INT48: + return TOSA_REF_TYPE_INT48; + case DType_FP32: + return TOSA_REF_TYPE_FP32; + case DType_UINT16: + return TOSA_REF_TYPE_UINT16; + case DType_FP16: + return TOSA_REF_TYPE_FP16; + case DType_BF16: + return TOSA_REF_TYPE_BF16; + default: + break; + } + return TOSA_REF_TYPE_UNKNOWN; +} + +}; // namespace TosaReference + +#endif diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h index c1f8ef6..b92845b 100644 --- a/reference_model/include/func_config.h +++ b/reference_model/include/func_config.h @@ -48,6 +48,7 @@ struct func_config_t uint32_t tosa_profile = 1; uint32_t dump_intermediates = 0; std::string fp_format = "0.5"; + uint32_t precise_mode = 0; bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian() tosa_level_t tosa_level; |