aboutsummaryrefslogtreecommitdiff
path: root/reference_model/include
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/include
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/include')
-rw-r--r--reference_model/include/dtype.h132
-rw-r--r--reference_model/include/func_config.h1
2 files changed, 133 insertions, 0 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
new file mode 100644
index 0000000..4976b54
--- /dev/null
+++ b/reference_model/include/dtype.h
@@ -0,0 +1,132 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_DTYPE_H
+#define TOSA_REFERENCE_DTYPE_H
+
+#include "model_common.h"
+#include "tosa_generated.h"
+#include <cstdint>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// Reference Model version of tosa.fbs enum DType
+// Plus a FP64 data type for precise mode.
+enum TOSA_REF_TYPE : uint32_t
+{
+ TOSA_REF_TYPE_UNKNOWN = 0,
+ TOSA_REF_TYPE_BOOL = 1,
+ TOSA_REF_TYPE_UINT8 = 2,
+ TOSA_REF_TYPE_INT4 = 3,
+ TOSA_REF_TYPE_INT8 = 4,
+ TOSA_REF_TYPE_INT16 = 5,
+ TOSA_REF_TYPE_INT32 = 6,
+ TOSA_REF_TYPE_INT48 = 7,
+ TOSA_REF_TYPE_FP32 = 8,
+ TOSA_REF_TYPE_UINT16 = 9,
+ TOSA_REF_TYPE_FP16 = 10,
+ TOSA_REF_TYPE_BF16 = 11,
+ TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above
+};
+
+inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
+{
+ switch (e)
+ {
+ case TOSA_REF_TYPE_UNKNOWN:
+ return EnumNameDType(DType_UNKNOWN);
+ case TOSA_REF_TYPE_BOOL:
+ return EnumNameDType(DType_BOOL);
+ case TOSA_REF_TYPE_UINT8:
+ return EnumNameDType(DType_UINT8);
+ case TOSA_REF_TYPE_INT4:
+ return EnumNameDType(DType_INT4);
+ case TOSA_REF_TYPE_INT8:
+ return EnumNameDType(DType_INT8);
+ case TOSA_REF_TYPE_INT16:
+ return EnumNameDType(DType_INT16);
+ case TOSA_REF_TYPE_INT32:
+ return EnumNameDType(DType_INT32);
+ case TOSA_REF_TYPE_INT48:
+ return EnumNameDType(DType_INT48);
+ case TOSA_REF_TYPE_FP32:
+ return EnumNameDType(DType_FP32);
+ case TOSA_REF_TYPE_UINT16:
+ return EnumNameDType(DType_UINT16);
+ case TOSA_REF_TYPE_FP16:
+ return EnumNameDType(DType_FP16);
+ case TOSA_REF_TYPE_BF16:
+ return EnumNameDType(DType_BF16);
+ case TOSA_REF_TYPE_FP64:
+ return "FP64";
+ default:
+ assert(false);
+ }
+}
+
+// return corresponding TOSA_REF_TYPE for DType
+inline TOSA_REF_TYPE ConvertDType(const DType dtype)
+{
+ assert(DType_MAX == DType_BF16); // must update whenever DType_MAX changes
+
+ if (g_func_config.precise_mode)
+ {
+ // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64
+ switch (dtype)
+ {
+ case DType_FP16:
+ case DType_FP32:
+ case DType_BF16:
+ return TOSA_REF_TYPE_FP64;
+ default:
+ break;
+ }
+ }
+
+ switch (dtype)
+ {
+ case DType_BOOL:
+ return TOSA_REF_TYPE_BOOL;
+ case DType_UINT8:
+ return TOSA_REF_TYPE_UINT8;
+ case DType_INT4:
+ return TOSA_REF_TYPE_INT4;
+ case DType_INT8:
+ return TOSA_REF_TYPE_INT8;
+ case DType_INT16:
+ return TOSA_REF_TYPE_INT16;
+ case DType_INT32:
+ return TOSA_REF_TYPE_INT32;
+ case DType_INT48:
+ return TOSA_REF_TYPE_INT48;
+ case DType_FP32:
+ return TOSA_REF_TYPE_FP32;
+ case DType_UINT16:
+ return TOSA_REF_TYPE_UINT16;
+ case DType_FP16:
+ return TOSA_REF_TYPE_FP16;
+ case DType_BF16:
+ return TOSA_REF_TYPE_BF16;
+ default:
+ break;
+ }
+ return TOSA_REF_TYPE_UNKNOWN;
+}
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index c1f8ef6..b92845b 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -48,6 +48,7 @@ struct func_config_t
uint32_t tosa_profile = 1;
uint32_t dump_intermediates = 0;
std::string fp_format = "0.5";
+ uint32_t precise_mode = 0;
bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()
tosa_level_t tosa_level;