From a4d748b08accce06fab93e2d2b96e499b35ae89b Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Tue, 28 Mar 2023 22:06:56 +0000
Subject: [reference model] Add precise mode

This adds the --precise_mode=1 option to tosa_reference_model, which
causes the reference model to convert all floating-point tensors to
FP64 tensors and compute all operators accordingly.

Also adds an optional -p argument to the test runners
tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to
run tests in precise mode.

Signed-off-by: Tai Ly
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
---
 reference_model/CMakeLists.txt                |    3 +-
 reference_model/include/dtype.h               |  132 +++
 reference_model/include/func_config.h         |    1 +
 reference_model/src/arith_util.h              |   15 +-
 reference_model/src/command_line_utils.h      |    1 +
 reference_model/src/graph_node.h              |   21 +-
 reference_model/src/model_runner_impl.cc      |    8 +-
 reference_model/src/ops/activation_funcs.cc   |   56 +-
 reference_model/src/ops/activation_funcs.h    |    8 +-
 reference_model/src/ops/comparison.cc         |   44 +-
 reference_model/src/ops/comparison.h          |   26 +-
 reference_model/src/ops/control_flow.cc       |   21 +-
 reference_model/src/ops/data_layout.cc        |   93 +-
 reference_model/src/ops/data_layout.h         |   22 +-
 reference_model/src/ops/data_nodes.cc         |   11 +-
 reference_model/src/ops/data_nodes.h          |    4 +-
 reference_model/src/ops/ewise_binary.cc       |  210 ++--
 reference_model/src/ops/ewise_binary.h        |   36 +-
 reference_model/src/ops/ewise_ternary.cc      |   18 +-
 reference_model/src/ops/ewise_ternary.h       |   10 +-
 reference_model/src/ops/ewise_unary.cc        |  164 ++-
 reference_model/src/ops/ewise_unary.h         |   10 +-
 reference_model/src/ops/image.cc              |   55 +-
 reference_model/src/ops/image.h               |    4 +-
 reference_model/src/ops/op_factory.cc         |   60 +-
 reference_model/src/ops/op_factory.h          |   82 +-
 reference_model/src/ops/reduction.cc          |  111 +-
 reference_model/src/ops/reduction.h           |   38 +-
 reference_model/src/ops/scatter_gather.cc     |   20 +-
 reference_model/src/ops/scatter_gather.h      |    6 +-
 reference_model/src/ops/template_types.h      |   96 +-
 reference_model/src/ops/tensor_ops.cc         |  203 ++--
 reference_model/src/ops/tensor_ops.h          |   24 +-
 reference_model/src/ops/type_conversion.cc    |  116 +-
 reference_model/src/ops/type_conversion.h     |   99 +-
 reference_model/src/subgraph_traverser.cc     |   90 +-
 reference_model/src/tensor.cc                 | 1251 ++++++++++++++++----
 reference_model/src/tensor.h                  |  209 +++-
 thirdparty/serialization_lib                  |    2 +-
 .../tosa_verif_framework_compiler_runner.py   |   16 +-
 verif/runner/tosa_refmodel_sut_run.py         |    3 +
 verif/runner/tosa_verif_run_tests.py          |    7 +
 42 files changed, 2412 insertions(+), 994 deletions(-)
 create mode 100644 reference_model/include/dtype.h

diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 6494225..a086f4b 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required (VERSION 3.4)
 
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -109,6 +109,7 @@ set(PUBLIC_HEADERS)
 list(APPEND PUBLIC_HEADERS
     include/debug_modes.def
     include/debug_types.h
+    include/dtype.h
     include/func_config.h
     include/func_debug.h
     include/graph_status.h
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
new file mode 100644
index 0000000..4976b54
--- /dev/null
+++ b/reference_model/include/dtype.h
@@ -0,0 +1,132 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_DTYPE_H
+#define TOSA_REFERENCE_DTYPE_H
+
+#include "model_common.h"
+#include "tosa_generated.h"
+#include <cstdint>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// Reference Model version of tosa.fbs enum DType
+// Plus a FP64 data type for precise mode.
+enum TOSA_REF_TYPE : uint32_t
+{
+    TOSA_REF_TYPE_UNKNOWN = 0,
+    TOSA_REF_TYPE_BOOL    = 1,
+    TOSA_REF_TYPE_UINT8   = 2,
+    TOSA_REF_TYPE_INT4    = 3,
+    TOSA_REF_TYPE_INT8    = 4,
+    TOSA_REF_TYPE_INT16   = 5,
+    TOSA_REF_TYPE_INT32   = 6,
+    TOSA_REF_TYPE_INT48   = 7,
+    TOSA_REF_TYPE_FP32    = 8,
+    TOSA_REF_TYPE_UINT16  = 9,
+    TOSA_REF_TYPE_FP16    = 10,
+    TOSA_REF_TYPE_BF16    = 11,
+    TOSA_REF_TYPE_FP64    = 99,    // FP64 is special: add new data types above
+};
+
+inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
+{
+    switch (e)
+    {
+        case TOSA_REF_TYPE_UNKNOWN:
+            return EnumNameDType(DType_UNKNOWN);
+        case TOSA_REF_TYPE_BOOL:
+            return EnumNameDType(DType_BOOL);
+        case TOSA_REF_TYPE_UINT8:
+            return EnumNameDType(DType_UINT8);
+        case TOSA_REF_TYPE_INT4:
+            return EnumNameDType(DType_INT4);
+        case TOSA_REF_TYPE_INT8:
+            return EnumNameDType(DType_INT8);
+        case TOSA_REF_TYPE_INT16:
+            return EnumNameDType(DType_INT16);
+        case TOSA_REF_TYPE_INT32:
+            return EnumNameDType(DType_INT32);
+        case TOSA_REF_TYPE_INT48:
+            return EnumNameDType(DType_INT48);
+        case TOSA_REF_TYPE_FP32:
+            return EnumNameDType(DType_FP32);
+        case TOSA_REF_TYPE_UINT16:
+            return EnumNameDType(DType_UINT16);
+        case TOSA_REF_TYPE_FP16:
+            return EnumNameDType(DType_FP16);
+        case TOSA_REF_TYPE_BF16:
+            return EnumNameDType(DType_BF16);
+        case TOSA_REF_TYPE_FP64:
+            return "FP64";
+        default:
+            assert(false);
+    }
+}
+
+// return corresponding TOSA_REF_TYPE for DType
+inline TOSA_REF_TYPE ConvertDType(const DType dtype)
+{
+    assert(DType_MAX == DType_BF16);    // must update whenever DType_MAX changes
+
+    if (g_func_config.precise_mode)
+    {
+        // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64
+        switch (dtype)
+        {
+            case DType_FP16:
+            case DType_FP32:
+            case DType_BF16:
+                return TOSA_REF_TYPE_FP64;
+            default:
+                break;
+        }
+    }
+
+    switch (dtype)
+    {
+        case DType_BOOL:
+            return TOSA_REF_TYPE_BOOL;
+        case DType_UINT8:
+            return TOSA_REF_TYPE_UINT8;
+        case DType_INT4:
+            return TOSA_REF_TYPE_INT4;
+        case DType_INT8:
+            return TOSA_REF_TYPE_INT8;
+        case DType_INT16:
+            return TOSA_REF_TYPE_INT16;
+        case DType_INT32:
+            return TOSA_REF_TYPE_INT32;
+        case DType_INT48:
+            return TOSA_REF_TYPE_INT48;
+        case DType_FP32:
+            return TOSA_REF_TYPE_FP32;
+        case DType_UINT16:
+            return TOSA_REF_TYPE_UINT16;
+        case DType_FP16:
+            return TOSA_REF_TYPE_FP16;
+        case DType_BF16:
+            return TOSA_REF_TYPE_BF16;
+        default:
+            break;
+    }
+    return TOSA_REF_TYPE_UNKNOWN;
+}
+
+};    // namespace TosaReference
+
+#endif
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index c1f8ef6..b92845b 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -48,6 +48,7 @@ struct func_config_t
     uint32_t tosa_profile       = 1;
     uint32_t dump_intermediates = 0;
     std::string fp_format       = "0.5";
+    uint32_t precise_mode       = 0;
     bool float_is_big_endian    = false;    // Set in arith_util.h by float_is_big_endian()
 
     tosa_level_t tosa_level;
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 59bdf44..76f3bb8 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -30,11 +30,11 @@
 #include
 #include
 #define __STDC_LIMIT_MACROS //enable min/max of plain data type
+#include "dtype.h"
 #include "func_config.h"
 #include "func_debug.h"
 #include "half.hpp"
 #include "inttypes.h"
-#include "tosa_generated.h"
 #include
 #include
 #include
@@ -45,6 +45,7 @@
 
 using namespace tosa;
 using namespace std;
+using namespace TosaReference;
 
 inline size_t _count_one(uint64_t val)
 {
@@ -259,27 +260,27 @@ inline bool float_is_big_endian()
     return f_as_bytes[0] != f_neg_as_bytes[0];
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 float fpTrunc(float f_in)
 {
-    /* Truncates a float value based on the DType it represents.*/
+    /* Truncates a float value based on the TOSA_REF_TYPE it represents.*/
     switch (Dtype)
     {
-        case DType_BF16:
+        case TOSA_REF_TYPE_BF16:
             truncateFloatToBFloat(&f_in, 1);
             break;
-        case DType_FP16:
+        case TOSA_REF_TYPE_FP16:
             // Cast to temporary float16 value before casting back to float32
         {
             half_float::half h = half_float::half_cast<half_float::half>(f_in);
             f_in               = half_float::half_cast<float>(h);
             break;
         }
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP32:
             // No-op for fp32
             break;
         default:
-            ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));
+            ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype));
     }
     return f_in;
 }
diff --git a/reference_model/src/command_line_utils.h b/reference_model/src/command_line_utils.h
index 4e6e555..8fdff01 100644
--- a/reference_model/src/command_line_utils.h
+++ b/reference_model/src/command_line_utils.h
@@ -59,6 +59,7 @@ int func_model_parse_cmd_line(
         ("tosa_level", "Set TOSA level (NONE, EIGHTK)", cxxopts::value(func_config.tosa_level))
         ("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value<uint32_t>(func_config.dump_intermediates))
+        ("p,precise_mode", "Calculate floating point operations in FP64 (0/1)", cxxopts::value<uint32_t>(func_config.precise_mode))
         ("v,version", "print model version")
         ("i,input_tensor_file", "specify input tensor files", cxxopts::value<std::vector<std::string>>())
         ("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value<std::string>())
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index a9a336b..3433192 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -19,29 +19,30 @@
 #include "attribute.h"
 #include "subgraph_traverser.h"
 #include "tensor.h"
-#include "tosa_generated.h"
 #include
 
-#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE)                                                             \
+    template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE>;
 
 #define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2)                                                    \
-    template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
+    template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
 
 #define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE)                                                     \
-    template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>;
+    template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>;
 
 #define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2)                                            \
-    template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
+    template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
 
-#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE)                                                                            \
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE>;
 
-#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
+#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2)                                                                   \
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
 
-#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3)                                                         \
-    template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>;
+#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3)                                                         \
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;
 
 #define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE)                                                 \
-    template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, OP_TYPE>;
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
 
 #define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE)                                                           \
     DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE)                                                                    \
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index fa39c75..8089a1a 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -186,13 +186,13 @@ int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t s
     int status = 0;
     switch (tensor->getDtype())
     {
-        case DType::DType_FP16: {
+        case TOSA_REF_TYPE_FP16: {
             auto typed_ptr     = reinterpret_cast<half_float::half*>(raw_ptr);
             const int elements = size / sizeof(half_float::half);
             status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
             break;
         }
-        case DType::DType_FP32: {
+        case TOSA_REF_TYPE_FP32: {
             auto typed_ptr     = reinterpret_cast<float*>(raw_ptr);
             const int elements = size / sizeof(float);
             status             = setInput(input_name, ArrayProxy(elements, typed_ptr));
@@ -252,13 +252,13 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
     int status = 0;
     switch (tensor->getDtype())
     {
-        case DType::DType_FP16: {
+        case TOSA_REF_TYPE_FP16: {
             auto typed_ptr     = reinterpret_cast<half_float::half*>(raw_ptr);
             const int elements = size / sizeof(half_float::half);
             status             = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
             break;
         }
-        case DType::DType_FP32: {
+        case TOSA_REF_TYPE_FP32: {
             auto typed_ptr     = reinterpret_cast<float*>(raw_ptr);
             const int elements = size / sizeof(float);
             status             = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 24bd077..6681d6d 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpClamp<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -32,9 +32,9 @@ int OpClamp<Rank, Dtype>::register_fcn()
 
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
             InEigenType max = (InEigenType)attribute->max_fp();
@@ -43,8 +43,17 @@ int OpClamp<Rank, Dtype>::register_fcn()
             this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); };
         }
         break;
-        case DType_INT8:
-        case DType_INT16:
+        case TOSA_REF_TYPE_FP64:
+        {
+            InEigenType min = (InEigenType)attribute->min_fp();
+            InEigenType max = (InEigenType)attribute->max_fp();
+            ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
+        }
+        break;
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
         {
             InEigenType min = (InEigenType)attribute->min_int();
             InEigenType max = (InEigenType)attribute->max_int();
@@ -53,19 +62,19 @@ int OpClamp<Rank, Dtype>::register_fcn()
         }
         break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpClamp<Rank, Dtype>::~OpClamp()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpSigmoid<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -74,21 +83,24 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
 
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
             this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.f / (1.f + (expf(-1.f * a)))); };
             break;
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a) -> OutEigenType { return (1.L / (1.L + (exp(-1.L * a)))); };
+            break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTanh<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -97,13 +109,16 @@ int OpTanh<Rank, Dtype>::register_fcn()
 
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
             this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); };
             break;
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a) -> OutEigenType { return tanh(a); };
+            break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
@@ -115,11 +130,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 9a697cd..2372fcb 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace tosa;
 namespace TosaReference
 {
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpClamp : public UnaryNode<Rank, Dtype>
 {
 public:
@@ -45,7 +45,7 @@ protected:
     TosaClampAttribute* attribute;
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpSigmoid : public UnaryNode<Rank, Dtype>
 {
 public:
@@ -61,7 +61,7 @@ public:
     virtual int register_fcn();
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpTanh : public UnaryNode<Rank, Dtype>
 {
 public:
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index a5711eb..8a084c7 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpEqual<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -31,20 +31,21 @@
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_INT32:
+        case TOSA_REF_TYPE_FP64:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpGreater<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -53,20 +54,21 @@
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_INT32:
+        case TOSA_REF_TYPE_FP64:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpGreaterEqual<Rank, Dtype>::register_fcn()
 {
     // Check Tosa Level
@@ -75,14 +77,15 @@
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_INT32:
+        case TOSA_REF_TYPE_FP64:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
@@ -93,13 +96,16 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
index 29e6b5a..263df6a 100644
--- a/reference_model/src/ops/comparison.h
+++ b/reference_model/src/ops/comparison.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -24,45 +24,45 @@ using namespace tosa;
 namespace TosaReference
 {
 
-template <int Rank, DType Dtype>
-class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
 {
 public:
     OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
-        : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+        : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
     {
         register_fcn();
     }
     using InEigenType  = typename GetEigenType<Dtype>::type;
-    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
     virtual int register_fcn();
 };
 
-template <int Rank, DType Dtype>
-class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
 {
 public:
     OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
-        : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_)
+        : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_GREATER, id_)
    {
         register_fcn();
     }
     using InEigenType  = typename GetEigenType<Dtype>::type;
-    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
     virtual int register_fcn();
 };
 
-template <int Rank, DType Dtype>
-class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
 {
 public:
     OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
-        : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+        : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
     {
         register_fcn();
     }
     using InEigenType  = typename GetEigenType<Dtype>::type;
-    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+    using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
     virtual int register_fcn();
 };
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index f573d5b..03ad6c6 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes()
 {
     ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
 
-    ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
-             "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+    ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
+             "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
              inputs[0]->getRank());
 
     cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
@@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes()
         std::string else_block_input_name = else_block->GetInputs()[i];
         TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
         TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
-        ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
+        ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
                  "OpCondIf: input tensor type mismatch with then_block input type");
-        ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
+        ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
                  "OpCondIf: input tensor type mismatch with else_block input type");
         ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
                  "OpCondIf: input tensor rank mismatch with then_block input rank");
@@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes()
         std::string else_block_output_name = else_block->GetOutputs()[i];
         TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
         TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
-        ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
+        ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
                  "OpCondIf: output tensor type mismatch with then_block output type");
-        ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
+        ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
                  "OpCondIf: output tensor type mismatch with else_block output type");
         ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
                  "OpCondIf: output tensor rank mismatch with then_block output rank");
@@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes()
         TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
         TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
 
-        ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
+        ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
                  "OpWhileLoop: input tensor type mismatch with cond_block input type");
-        ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
+        ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
                  "OpWhileLoop: input tensor type mismatch with body_block input type");
-        ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
+        ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
                  "OpWhileLoop: input tensor type mismatch with body_block output type");
         ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
                  "OpWhileLoop: input tensor rank mismatch with cond_block input rank");
@@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes()
 int OpWhileLoop::eval()
 {
-
-    TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int>({}));
+    TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int>({}));
 
     cond_output_ctensor.allocate();
     std::vector<TosaReference::Tensor*> cond_block_outputs;
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index a189466..442cef8 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -20,7 +20,7 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Axis);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpConcat<Rank, Dtype>::~OpConcat()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpConcat<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpConcat<Rank, Dtype>::eval()
 {
 
@@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
@@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Pad);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpPad<Rank, Dtype>::~OpPad()
 {
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpPad<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -185,22 +185,23 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpPad<Rank, Dtype>::eval()
 {
     InEigenType pad_value = 0;
 
     switch (Dtype)
     {
-        case DType_BOOL:
-        case DType_INT8:
-        case DType_INT16:
-        case DType_INT32:
+        case TOSA_REF_TYPE_BOOL:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
             pad_value = (InEigenType)attribute->pad_const_int();
             break;
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_FP64:
            pad_value = (InEigenType)attribute->pad_const_fp();
            break;
        default:
@@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
                                              TosaAttributeBase* attribute_,
                                              uint64_t id_)
@@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Reshape);
 }
 
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 OpReshape<InRank, OutRank, Dtype>::~OpReshape()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 int OpReshape<InRank, OutRank, Dtype>::eval()
 {
     for (int32_t d = 0; d < OutRank; d++)
@@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
                                   TosaAttributeBase* attribute_,
                                   uint64_t id_)
@@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Axis);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpReverse<Rank, Dtype>::~OpReverse()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpReverse<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpReverse<Rank, Dtype>::eval()
 {
     out->getTensor() = in->getTensor().reverse(reverse_array);
 
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
                               TosaAttributeBase* attribute_,
                               uint64_t id_)
@@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Slice);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpSlice<Rank, Dtype>::~OpSlice()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpSlice<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpSlice<Rank, Dtype>::eval()
 {
     out->getTensor() = in->getTensor().slice(begin_array, size_array);
 
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
                                     TosaAttributeBase* attribute_,
                                     uint64_t id_)
@@ -469,14 +470,14 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Tile);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpTileBase<Rank, Dtype>::~OpTileBase()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTileBase<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTile<Rank, Dtype>::eval()
 {
     // primary template shouldn't be called
-    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+    FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<1, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<2, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<3, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<4, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<5, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpTile<6, Dtype>::eval()
 {
     for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
                                       TosaAttributeBase* attribute_,
                                       uint64_t id_)
@@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Transpose);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpTranspose<Rank, Dtype>::~OpTranspose()
 {
     if (attribute) delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTranspose<Rank, Dtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpTranspose<Rank, Dtype>::eval()
 {
     out->getTensor() = in->getTensor().shuffle(perm_array);
@@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
 
 DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
 DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
 DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
 DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
 DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -799,3 +807,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 3a6cb0d..94ce248 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -23,7 +23,7 @@ using namespace tosa;
 namespace TosaReference
 {
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpConcat : public GraphNode
 {
 public:
@@ -45,7 +45,7 @@ protected:
     TosaReference::TensorTemplate<TOut>* out;
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpPad : public GraphNode
 {
 public:
@@ -66,7 +66,7 @@ protected:
     TosaPadAttribute* attribute;
 };
 
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
 class OpReshape : public GraphNode
 {
 public:
@@ -90,7 +90,7 @@ protected:
     TosaReference::TensorTemplate<TOut>* out;
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpReverse : public GraphNode
 {
 public:
@@ -112,7 +112,7 @@ protected:
     Eigen::array<bool, Rank> reverse_array;
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpSlice : public GraphNode
 {
 public:
@@ -135,7 +135,7 @@ protected:
     TosaReference::TensorTemplate<TOut>* out;
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpTileBase : public GraphNode
 {
 public:
@@ -156,7 +156,7 @@ protected:
 };
 
 // primary template for op tile
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpTile : public OpTileBase<Rank, Dtype>
 {
 public:
@@ -170,12 +170,12 @@ protected:
 
 // partial specialization for specific rank
 #define DEF_OP_TILE_RANK(N)                                                            \
-    template <DType Dtype>                                                             \
+    template <TOSA_REF_TYPE Dtype>                                                     \
     class OpTile<N, Dtype> : public OpTileBase<N, Dtype>                               \
     {                                                                                  \
     public:                                                                            \
-        OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)   \
-            : OpTileBase<N, Dtype>(sgt_, attribute_, id_)                              \
+        OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)   \
+            : OpTileBase<N, Dtype>(sgt_, attribute_, id_)                              \
         {}                                                                             \
                                                                                        \
     protected:                                                                         \
@@ -191,7 +191,7 @@ DEF_OP_TILE_RANK(6)
 
 #undef DEF_OP_TILE_RANK
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpTranspose : public GraphNode
 {
 public:
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index f5304a5..b7f987a 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ int OpConst::eval()
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
                                     TosaAttributeBase* attribute_,
                                     uint64_t id_)
@@ -52,11 +52,11 @@ OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
     setRequiredRank(0, 6);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpIdentity<Rank, Dtype>::~OpIdentity()
 {}
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpIdentity<Rank, Dtype>::checkTensorAttributes()
 {
 
@@ -78,7 +78,7 @@ int OpIdentity<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpIdentity<Rank, Dtype>::eval()
 {
     out->getTensor() = in->getTensor();
@@ -96,3 +96,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index 8761a08..395c667 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ public:
     virtual int eval();
 };
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpIdentity : public GraphNode
 {
 public:
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 6aa0c0f..c5801e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -22,7 +22,7 @@ using namespace TosaReference;
 using namespace Eigen;
 using namespace tosa;
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
                                                         const Op& op_,
                                                         uint64_t id_)
@@ -37,11 +37,11 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
     fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
 }
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
 {}
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -90,7 +90,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
 {
     const std::vector<int>& a_shape = a->getShape();
@@ -106,7 +106,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
     return 0;
 }
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int BinaryNode<Rank, InDtype, OutDtype>::eval()
 {
     this->broadcast();
@@ -124,7 +124,7 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval()
 }
 
 // still need to partial specialize this, or Eigen will throw static assertion
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int BinaryNode<0, InDtype, OutDtype>::eval()
 {
     this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
@@ -132,12 +132,12 @@ int BinaryNode<0, InDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpAdd<Rank, Dtype>::register_fcn()
 {
     switch (InDtype)
     {
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 int64_t res_in_64     = static_cast<int64_t>(a) + b;
                 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -146,36 +146,39 @@ int OpAdd<Rank, Dtype>::register_fcn()
                 return static_cast<InEigenType>(res_in_64);
             };
             break;
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
             break;
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+            break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
 {
     bool round       = attribute->round();
     int32_t num_bits = 0;
     switch (Dtype)
     {
-        case DType_INT8:
+        case TOSA_REF_TYPE_INT8:
             num_bits = 8;
             break;
-        case DType_INT16:
+        case TOSA_REF_TYPE_INT16:
             num_bits = 16;
             break;
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             num_bits = 32;
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
@@ -195,69 +198,69 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpBitwiseAnd<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_INT8:
-        case DType_INT16:
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpBitwiseOr<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_INT8:
-        case DType_INT16:
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpBitwiseXor<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_INT8:
-        case DType_INT16:
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpIntdiv<Rank, Dtype>::register_fcn()
 {
     switch (InDtype)
     {
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
                 int64_t res_in_64     = static_cast<int64_t>(a) / b;
@@ -268,47 +271,47 @@ int OpIntdiv<Rank, Dtype>::register_fcn()
             };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpLogicalAnd<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_BOOL:
+        case TOSA_REF_TYPE_BOOL:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_INT8:
+        case TOSA_REF_TYPE_INT8:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
                 return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
             };
             break;
-        case DType_INT16:
+        case TOSA_REF_TYPE_INT16:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
                 return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
             };
             break;
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
@@ -316,32 +319,32 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
             };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpLogicalRightShift<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_INT8:
+        case TOSA_REF_TYPE_INT8:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
                 return static_cast<OutEigenType>(static_cast<uint8_t>(a) >> b);
             };
             break;
-        case DType_INT16:
+        case TOSA_REF_TYPE_INT16:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
                 return static_cast<OutEigenType>(static_cast<uint16_t>(a) >> b);
             };
             break;
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
                         (int32_t)b);
@@ -349,91 +352,96 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn()
             };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpLogicalOr<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_BOOL:
+        case TOSA_REF_TYPE_BOOL:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpLogicalXor<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_BOOL:
+        case TOSA_REF_TYPE_BOOL:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpMaximum<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_FP64:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpMinimum<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
+        case TOSA_REF_TYPE_FP64:
+        case TOSA_REF_TYPE_INT32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
            break;
        default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 int OpMul<Rank, InDtype, OutDtype>::register_fcn()
 {
     int32_t shift = attribute->shift();
     switch (InDtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
             this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
             break;
-        case DType_INT32:
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+            break;
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
                 int64_t result;
                 if (shift > 0)
@@ -457,8 +465,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
                 return static_cast<OutEigenType>(result);
             };
             break;
-        case DType_INT8:
-        case DType_INT16:
+        case TOSA_REF_TYPE_INT8:
+        case TOSA_REF_TYPE_INT16:
             this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
                 OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
 
@@ -468,41 +476,44 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
             };
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 OpMul<Rank, InDtype, OutDtype>::~OpMul()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpPow<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
            break;
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); };
+            break;
        default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpSub<Rank, Dtype>::register_fcn()
 {
     switch (InDtype)
     {
-        case DType_INT32:
+        case TOSA_REF_TYPE_INT32:
             this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
                 int64_t res_in_64     = static_cast<int64_t>(a) - b;
                 int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -511,19 +522,22 @@ int OpSub<Rank, Dtype>::register_fcn()
                 return static_cast<InEigenType>(res_in_64);
             };
             break;
-        case DType_FP16:
-        case DType_BF16:
-        case DType_FP32:
+        case TOSA_REF_TYPE_FP16:
+        case TOSA_REF_TYPE_BF16:
+        case TOSA_REF_TYPE_FP32:
             this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
             break;
+        case TOSA_REF_TYPE_FP64:
+            this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+            break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
     }
 
     return 0;
 }
 
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
 OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -535,13 +549,13 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Table);
 }
 
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
 OpTable<Rank, InDtype>::~OpTable()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
 int OpTable<Rank, InDtype>::checkTensorAttributes()
 {
     // Check Tosa Level
@@ -573,12 +587,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
 int OpTable<Rank, InDtype>::eval()
 {
     switch (InDtype)
     {
-        case DType_INT8:
+        case TOSA_REF_TYPE_INT8:
             this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
                 int32_t input_truncated = std::min(std::max(in, QInMin), QInMax);
                 int32_t index           = input_truncated - QInMin;
@@ -587,7 +601,7 @@ int OpTable<Rank, InDtype>::eval()
                 return value;
             });
             break;
-        case DType_INT16:
+        case TOSA_REF_TYPE_INT16:
             this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
                 // 1. make sure input is int16 range
                 int32_t input_truncated = std::min(std::max(in, QInMin), QInMax);
@@ -610,7 +624,7 @@ int OpTable<Rank, InDtype>::eval()
             });
             break;
         default:
-            ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+            ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
     }
 
     return GraphNode::eval();
@@ -630,11 +644,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
@@ -672,11 +688,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
@@ -684,15 +702,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
 
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
@@ -703,3 +724,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 020ddb5..5f6e531 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ namespace TosaReference
 // the way of registering lambda + .binaryExpr() might sacrifice performance here
 // but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...}
 // needs to revisit if performance becomes a bottleneck here
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 class BinaryNodeBase : public GraphNode
 {
 public:
@@ -67,7 +67,7 @@ protected:
 };
 
 // primary class
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
 {
 public:
@@ -86,7 +86,7 @@ public:
 };
 
 // partial specialization for rank 0
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
 {
 public:
@@ -100,19 +100,19 @@ public:
 };
 
 #define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME)                                 \
-    template <int Rank, DType Dtype>                                                   \
+    template <int Rank, TOSA_REF_TYPE Dtype>                                           \
     class Op##Opname : public BinaryNode<Rank, Dtype, Dtype>                           \
     {                                                                                  \
    public:                                                                             \
-        Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
-            : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_)                   \
+        Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+            : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_)                   \
        {                                                                               \
            register_fcn();                                                             \
        }                                                                               \
-        static constexpr DType InDtype  = Dtype;                                       \
-        static constexpr DType OutDtype = Dtype;                                       \
-        using InEigenType               = typename GetEigenType<InDtype>::type;        \
-        using OutEigenType              = typename GetEigenType<OutDtype>::type;       \
+        static constexpr TOSA_REF_TYPE InDtype  = Dtype;                               \
+        static constexpr TOSA_REF_TYPE OutDtype = Dtype;                               \
+        using InEigenType                       = typename GetEigenType<InDtype>::type;  \
+        using OutEigenType                      = typename GetEigenType<OutDtype>::type; \
        virtual int register_fcn();                                                     \
    };
 
@@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
 
 #undef DEF_TEMPLATE_BINARY_OP_DEFAULT
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
 {
 public:
@@ -154,7 +154,7 @@ protected:
     TosaArithmeticRightShiftAttribute* attribute;
 };
 
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
 class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
 {
 public:
@@ -175,7 +175,7 @@ protected:
     TosaMulAttribute* attribute;
 };
 
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
 class OpTable : public GraphNode
 {
 public:
@@ -185,9 +185,11 @@ public:
     virtual int checkTensorAttributes();
     virtual int eval();
 
-    static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
-    static constexpr DType OutDtype   = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
-    static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
+    static constexpr TOSA_REF_TYPE TableDtype =
+        (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16;
+    static constexpr TOSA_REF_TYPE OutDtype =
+        (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32;
+    static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 256 : 513;
     using InEigenType    = typename GetEigenType<InDtype>::type;
     using TableEigenType = typename GetEigenType<TableDtype>::type;
     using OutEigenType   = typename GetEigenType<OutDtype>::type;
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 4d53ae4..090ce29 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template OpSelectBase::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpSelectBase::OpSelectBase(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template +template OpSelectBase::~OpSelectBase() {} -template +template int OpSelectBase::checkTensorAttributes() { // Check Tosa Level @@ -66,13 +66,13 @@ int OpSelectBase::checkTensorAttributes() return 0; } -template +template int OpSelectBase::eval() { FATAL_ERROR("shouldn't be called"); } -template +template int OpSelect::broadcast() { const std::vector& cond_shape = this->cond->getShape(); @@ -90,7 +90,7 @@ int OpSelect::broadcast() return 0; } -template +template int OpSelect::eval() { this->broadcast(); @@ -102,7 +102,7 @@ int OpSelect::eval() return GraphNode::eval(); } -template +template int OpSelect<0, Dtype>::eval() { this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor()); @@ -118,6 +118,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); @@ -126,3 +127,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64); diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index 75a2194..c6970cb 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ namespace TosaReference // 3. Else_val: Rank N, type= // 4. Result: Rank N, type= // Cond, Then_val, Else_val need to be mutually-broadcastable -template +template class OpSelectBase : public GraphNode { public: @@ -39,7 +39,7 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - using CondEigenType = typename GetEigenType::type; + using CondEigenType = typename GetEigenType::type; using InEigenType = typename GetEigenType::type; using TCond = Eigen::Tensor; using TIn = Eigen::Tensor; @@ -55,7 +55,7 @@ protected: }; // primary class -template +template class OpSelect : public OpSelectBase { public: @@ -69,7 +69,7 @@ public: }; // partial specialization for rank 0 -template +template class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8dc37e2..514cb84 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. 
+// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template UnaryNode::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -35,11 +35,11 @@ UnaryNode::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 }; } -template +template UnaryNode::~UnaryNode() {} -template +template int UnaryNode::checkTensorAttributes() { // Check Tosa Level @@ -69,7 +69,7 @@ int UnaryNode::checkTensorAttributes() return 0; } -template +template int UnaryNode::eval() { this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); @@ -77,71 +77,75 @@ int UnaryNode::eval() return GraphNode::eval(); } -template +template int OpAbs::register_fcn() { switch (Dtype) { - case DType_FP32: // No fpTrunc for FP32 as it is a no-op - case DType_INT32: + case TOSA_REF_TYPE_FP32: // No fpTrunc for FP32 as it is a no-op + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(a > (InEigenType)0 ? a : (-a)); }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpBitwiseNot::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpCeil::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(ceilf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpClz::register_fcn() { int32_t num_bits; switch (Dtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: num_bits = 32; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -163,73 +167,82 @@ int OpClz::register_fcn() return 0; } -template +template int OpExp::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(expf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", 
EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpFloor::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(floorf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLog::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(logf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return log(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpLogicalNot::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template OpNegate::OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -240,31 +253,37 @@ OpNegate::OpNegate(SubgraphTraverser* sgt_, register_fcn(); } -template +template OpNegate::~OpNegate() { if (attribute) delete attribute; } -template +template int OpNegate::register_fcn() { - ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); - ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return fpTrunc(result); }; break; - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { + OutEigenType result = -(a); + return result; + }; + break; + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0L - a; int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); @@ -272,7 +291,7 @@ int OpNegate::register_fcn() REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); int64_t max_clip_in_64, min_clip_in_64; - if (Dtype == DType_INT16) + if (Dtype == TOSA_REF_TYPE_INT16) { max_clip_in_64 = static_cast(std::numeric_limits::max()); min_clip_in_64 = static_cast(std::numeric_limits::min()); @@ -285,7 +304,7 @@ int OpNegate::register_fcn() return static_cast(std::min(max_clip_in_64, std::max(min_clip_in_64, res_in_64))); }; break; - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t 
res_in_64 = 0 - (a - attribute->input1_zp()); int64_t i32_max_in_64 = static_cast(std::numeric_limits::max()); @@ -297,41 +316,47 @@ int OpNegate::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpReciprocal::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / a); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template +template int OpRsqrt::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / sqrtf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / sqrt(a)); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -345,11 +370,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); @@ -358,20 +385,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -381,11 +412,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); 
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 16a4c88..21ee276 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ using namespace tosa; namespace TosaReference { -template +template class UnaryNode : public GraphNode { public: @@ -45,11 +45,11 @@ protected: }; #define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \ - template \ + template \ class Op##Opname : public UnaryNode \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ : UnaryNode(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ @@ -75,7 +75,7 @@ DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) #undef DEF_TEMPLATE_UNARY_OP // Negate is the only unary op with attributes -template +template class OpNegate : public UnaryNode { public: diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 190b354..ca12cfe 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
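A note on the ewise_unary.cc changes above: the FP16/BF16 lambdas compute in
float and narrow the result through fpTrunc (a no-op for FP32), while the
FP64 cases added for precise mode call the double-precision libm routines
(ceil, exp, floor, log) and return the value unnarrowed. A minimal sketch of
that dispatch, with illustrative names rather than the reference-model API:

    #include <cmath>
    #include <functional>
    #include <stdexcept>

    enum class RefType { FP32, FP64 };

    // register_fcn-style dispatch: float math for the 32-bit path,
    // full double math for the precise-mode FP64 path.
    std::function<double(double)> register_ceil(RefType t)
    {
        switch (t)
        {
            case RefType::FP32: // compute in float, as the FP32 case does
                return [](double a) { return static_cast<double>(std::ceil(static_cast<float>(a))); };
            case RefType::FP64: // precise mode: stay in double end to end
                return [](double a) { return std::ceil(a); };
        }
        throw std::invalid_argument("unsupported RefType");
    }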
@@ -23,7 +23,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template OpResize::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpResize::OpResize(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Resize); } -template +template OpResize::~OpResize() { if (attribute) delete attribute; } -template +template int OpResize::checkTensorAttributes() { if (validateRequiredOperands()) @@ -64,7 +64,8 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -72,7 +73,8 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -87,7 +89,7 @@ int OpResize::checkTensorAttributes() return 0; } -template +template int OpResize::eval() { int in_batch = in->getShape()[0]; @@ -157,24 +159,38 @@ int OpResize::eval() int32_t y = oy * scale_y_d + offset_y; int32_t x = ox * scale_x_d + offset_x; - float fy = static_cast(y) / static_cast(scale_y_n); - float fx = static_cast(x) / static_cast(scale_x_n); - - int32_t iy = floor(fy); - int32_t ix = floor(fx); - + int32_t iy; + int32_t ix; resize_t dy; resize_t dx; - if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || - (typeid(resize_t) == typeid(half_float::half))) + if (std::is_same::value) { - dy = (resize_t)(fy - iy); - dx = (resize_t)(fx - ix); + const double fy_double = static_cast(y) / static_cast(scale_y_n); + const double fx_double = static_cast(x) / static_cast(scale_x_n); + iy = floor(fy_double); + ix = floor(fx_double); + + dy = (resize_t)(fy_double - iy); + dx = (resize_t)(fx_double - ix); } else { - dy = (resize_t)(y - (iy * scale_y_n)); - dx = (resize_t)(x - (ix * scale_x_n)); + const float fy = static_cast(y) / static_cast(scale_y_n); + const float fx = static_cast(x) / static_cast(scale_x_n); + iy = floor(fy); + ix = floor(fx); + + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) + { + dy = (resize_t)(fy - iy); + dx = (resize_t)(fx - ix); + } + else + { + dy = (resize_t)(y - (iy * scale_y_n)); + dx = (resize_t)(x - (ix * scale_x_n)); + } } int32_t iy0 = MAX(iy, 0); @@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 508d2c8..6d5a418 100644 --- a/reference_model/src/ops/image.h +++ 
b/reference_model/src/ops/image.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template +template class OpResize : public GraphNode { public: diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1db3974..0a78884 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, Op opType, TosaAttributeBase* attribute, uint64_t id, - DType inputDType, + TOSA_REF_TYPE inputDTYPE, int inputRank, - DType outputDType, + TOSA_REF_TYPE outputDTYPE, int outputRank, - DType weightDType, + TOSA_REF_TYPE weightDTYPE, int weightRank) { switch (opType) @@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64); break; case Op_CONV2D: DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); break; case Op_CONV3D: DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); @@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); @@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64); break; case Op_FULLY_CONNECTED: DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); @@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64); break; case Op_MATMUL: DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16); @@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); 
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); @@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); break; // activation_funcs @@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64); break; // ewise_binary @@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64); break; case Op_ARITHMETIC_RIGHT_SHIFT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); @@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); @@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64); break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64); break; case Op_TABLE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); @@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); break; case Op_BITWISE_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); @@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); @@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); break; // ewise_ternary @@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64); break; // comparison @@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64); break; case Op_GREATER_EQUAL: 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64); break; // reduction @@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); @@ -343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT16); DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); + DEF_FACTORY_RESHAPE(OpReshape, FP64); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -392,6 +431,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); break; // scatter_gather @@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); + DEF_FACTORY_ONE_TYPE(OpGather, FP64); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP64); break; // image @@ -448,6 +493,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); + DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64); break; // data_nodes @@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); break; // type_conversion @@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 9117df4..f276e03 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
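The Op_CAST entries above (and every other case in newOp()) rely on the
factory macros that follow: compare the runtime TOSA_REF_TYPE and rank, and
on a match instantiate the operator with those values as compile-time
template arguments. A compilable sketch of that runtime-to-compile-time
dispatch, using illustrative names:

    #include <memory>

    enum class RefType { FP32, FP64 };

    struct NodeSketch { virtual ~NodeSketch() = default; };

    template <RefType T>
    struct OpIdentitySketch : NodeSketch {};

    // Roughly what one DEF_FACTORY_ONE_TYPE expansion boils down to:
    // one runtime check per supported dtype, each returning a distinct
    // template instantiation.
    std::unique_ptr<NodeSketch> new_identity(RefType dtype)
    {
        if (dtype == RefType::FP32)
            return std::make_unique<OpIdentitySketch<RefType::FP32>>();
        if (dtype == RefType::FP64) // precise-mode path added by this patch
            return std::make_unique<OpIdentitySketch<RefType::FP64>>();
        return nullptr; // falls through to the unsupported-op error path
    }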
@@ -23,19 +23,19 @@ #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP(sgt, attribute, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP(sgt, attribute, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP(sgt, attribute, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP(sgt, attribute, id); + return new OP(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ @@ -57,40 +57,42 @@ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \ - if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \ - && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \ + ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ { \ - return new OP(sgt, attribute, id); \ - } \ + return new OP(sgt, attribute, \ + id); \ + } #define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \ + outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } // Statement-expression to evaluate accumulate attribute in-place @@ -108,35 +110,41 @@ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \ "of this attribute is required in order to determine the accumulate type."); \ } \ - accumDType; \ - }) \ + ConvertDType(accumDType); \ + }) #define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == 
TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ + { \ + return new OP(sgt, attribute, id); \ + } + +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } -#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP(sgt, attribute, id); \ + return new OP(sgt, attribute, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -151,7 +159,7 @@ } #define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -165,7 +173,7 @@ } #define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ switch (inputRank) \ { \ @@ -180,7 +188,7 @@ } #define DEF_FACTORY_RESHAPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -292,11 +300,11 @@ public: tosa::Op opType, tosa::TosaAttributeBase* attribute, uint64_t id, - tosa::DType inputDType, + TOSA_REF_TYPE inputDTYPE, int inputRank, - tosa::DType outputDType, + TOSA_REF_TYPE outputDTYPE, int outputRank, - tosa::DType weightDType, + TOSA_REF_TYPE weightDTYPE, int weightRank); }; }; // namespace TosaReference diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index cd9d55f..bf8ba57 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template ReduceNode::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -30,14 +30,14 @@ ReduceNode::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa INIT_ATTRIBUTE(Axis); } -template +template ReduceNode::~ReduceNode() { if (attribute) delete attribute; } -template +template int ReduceNode::checkTensorAttributes() { if (validateRequiredOperands()) @@ -100,7 +100,7 @@ struct AnyReducer { bool finalize(const bool accum) const { return accum; } }; -template +template int OpReduceAll::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); @@ -108,7 +108,7 @@ int OpReduceAll::eval() return GraphNode::eval(); } -template +template int OpReduceAny::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); @@ -116,7 +116,7 @@ int OpReduceAny::eval() return GraphNode::eval(); } -template +template int OpReduceMax::eval() { this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -124,7 +124,7 @@ int OpReduceMax::eval() return GraphNode::eval(); } -template +template int OpReduceMin::eval() { this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -132,35 +132,74 @@ int OpReduceMin::eval() return GraphNode::eval(); } -template +template int OpReduceProduct::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); break; - default: + case TOSA_REF_TYPE_FP32: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + +struct ProductDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum *= val; + } + double initialize() const + { + return 1.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template +int OpReduceProductDouble::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, ProductDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); } -template +template int OpReduceSum::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); break; - default: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); @@ -183,7 +222,7 @@ struct SumRequiresReducer { SubgraphTraverser* parent_sgt; }; -template +template int 
OpReduceSumInt::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions()); @@ -191,6 +230,40 @@ int OpReduceSumInt::eval() return GraphNode::eval(); } +struct SumDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum += val; + } + double initialize() const + { + return 0.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template +int OpReduceSumDouble::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, SumDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); @@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); @@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index 6e98a76..aeb9f1d 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
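SumDoubleReducer and ProductDoubleReducer above implement the reducer
concept that Eigen's Tensor::reduce() drives: initialize() seeds the
accumulator, reduce() folds in one element, and finalize() projects the
result. A standalone sketch of the same concept outside Eigen (illustrative
names; not how the reference model invokes it):

    #include <vector>

    // Same shape as the patch's SumDoubleReducer: seed, accumulate, project.
    struct SumDoubleReducerSketch
    {
        static const bool PacketAccess = false; // no vectorized path provided
        void reduce(const double val, double* accum) const { *accum += val; }
        double initialize() const { return 0.0; }
        double finalize(const double accum) const { return accum; }
    };

    // How a reducer of this shape is driven, one element at a time.
    template <typename Reducer>
    double fold(const std::vector<double>& data, Reducer r)
    {
        double accum = r.initialize();
        for (double v : data)
            r.reduce(v, &accum);
        return r.finalize(accum); // e.g. fold(values, SumDoubleReducerSketch{})
    }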
@@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template +template class ReduceNode : public GraphNode { public: @@ -44,7 +44,7 @@ protected: TosaAxisAttribute* attribute; }; -template +template class OpReduceAll : public ReduceNode { public: @@ -54,7 +54,7 @@ public: virtual int eval(); }; -template +template class OpReduceAny : public ReduceNode { public: @@ -64,7 +64,7 @@ public: virtual int eval(); }; -template +template class OpReduceMax : public ReduceNode { public: @@ -74,7 +74,7 @@ public: virtual int eval(); }; -template +template class OpReduceMin : public ReduceNode { public: @@ -84,7 +84,7 @@ public: virtual int eval(); }; -template +template class OpReduceProduct : public ReduceNode { public: @@ -94,7 +94,17 @@ public: virtual int eval(); }; -template +template +class OpReduceProductDouble : public ReduceNode +{ +public: + OpReduceProductDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) + {} + virtual int eval(); +}; + +template class OpReduceSum : public ReduceNode { public: @@ -104,7 +114,7 @@ public: virtual int eval(); }; -template +template class OpReduceSumInt : public ReduceNode { public: @@ -114,6 +124,16 @@ public: virtual int eval(); }; +template +class OpReduceSumDouble : public ReduceNode +{ +public: + OpReduceSumDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : ReduceNode(sgt_, Op_REDUCE_SUM, attribute_, id_) + {} + virtual int eval(); +}; + }; // namespace TosaReference #endif diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bcd8ce5..80b6c58 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template OpGather::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpGather::OpGather(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); } -template +template OpGather::~OpGather() {} -template +template int OpGather::checkTensorAttributes() { if (validateRequiredOperands()) @@ -96,7 +96,7 @@ int OpGather::checkTensorAttributes() return 0; } -template +template int OpGather::eval() { for (int32_t n = 0; n < N; n++) @@ -116,7 +116,7 @@ int OpGather::eval() return GraphNode::eval(); } -template +template OpScatter::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -125,11 +125,11 @@ OpScatter::OpScatter(SubgraphTraverser* sgt_, setRequiredOperands(3, 1); } -template +template OpScatter::~OpScatter() {} -template +template int OpScatter::checkTensorAttributes() { if (validateRequiredOperands()) @@ -199,7 +199,7 @@ int OpScatter::checkTensorAttributes() return 0; } -template +template int OpScatter::eval() { // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. 
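For reference, the semantics OpScatter::eval() implements, and which the
FP64 instantiations below now also get: the result starts as a copy of
values_in, then each indexed slot along K is overwritten from the input.
A sketch with flat std::vector storage in place of Eigen tensors, shapes
per the TOSA spec (values_in [N,K,C], indices [N,W], input [N,W,C]);
names are illustrative:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    template <typename T>
    std::vector<T> scatter_sketch(std::vector<T> values_in, // copied: unchanged slots survive
                                  const std::vector<int32_t>& indices,
                                  const std::vector<T>& input,
                                  int N, int K, int W, int C)
    {
        for (int n = 0; n < N; n++)
            for (int w = 0; w < W; w++)
            {
                const int32_t k = indices[static_cast<size_t>(n) * W + w];
                assert(k >= 0 && k < K); // out-of-range indices are invalid
                for (int c = 0; c < C; c++)
                    values_in[(static_cast<size_t>(n) * K + k) * C + c] =
                        input[(static_cast<size_t>(n) * W + w) * C + c];
            }
        return values_in;
    }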
@@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index af09153..fb675a9 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template +template class OpGather : public GraphNode { public: @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate* output; }; -template +template class OpScatter : public GraphNode { public: diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index ece14b1..6dd6e76 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -16,11 +16,10 @@ #ifndef OP_TEMPLATE_TYPES_H #define OP_TEMPLATE_TYPES_H -#include "tosa_generated.h" -#include +#include "dtype.h" #include "half.hpp" +#include #include -#include "arith_util.h" using namespace tosa; @@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate>; template using Tensor6 = TensorTemplate>; -template +template struct GetEigenType; template <> -struct GetEigenType +struct GetEigenType +{ + using type = double; +}; +template <> +struct GetEigenType { using type = float; }; template <> -struct GetEigenType +struct GetEigenType { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType +struct GetEigenType { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = int64_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = bool; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; template <> -struct GetEigenType +struct GetEigenType { using type = int32_t; }; /* Get Accumulate Eigen Type: -Same behaviour as GetEigenType for all DTypes except the -single specialised case of DType_FP16. */ -template +Same behaviour as GetEigenType for all DTYPEs except the +single specialised case of TOSA_REF_TYPE_FP16. 
*/ +template struct GetAccEigenType; template <> -struct GetAccEigenType +struct GetAccEigenType { using type = half_float::half; }; -template +template struct GetAccEigenType { using type = typename GetEigenType::type; }; // Meta function to get number of bits -template +template struct GetNumBits { static constexpr int32_t value = 0; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 1; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 8; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 16; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 4; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 8; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 16; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 32; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 48; }; template <> -struct GetNumBits +struct GetNumBits { static constexpr int32_t value = 16; }; // Meta function to get quantized min/max in compile time -template +template struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = INT64_C(-8); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = INT64_C(-128); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = INT64_C(-32768); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = -(INT64_C(1) << 31); }; template <> -struct GetQMin +struct GetQMin { static constexpr int64_t value = -(INT64_C(1) << 47); }; -template +template struct GetQMax { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = INT64_C(255); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = INT64_C(65535); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = INT64_C(7); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = INT64_C(127); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = INT64_C(32767); }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = (INT64_C(1) << 31) - 1; }; template <> -struct GetQMax +struct GetQMax { static constexpr int64_t value = (INT64_C(1) << 47) - 1; }; diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b3845df..f8fd323 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, } int check_conv_attribute(tosa::TosaConvAttribute* attribute, - uint32_t conv_dimension, - std::vector input_shape, - std::vector output_shape, - std::vector weights, - uint32_t offset_kernel, - DType InDtype, - DType WeightDtype, - std::string& msg) + uint32_t conv_dimension, + std::vector input_shape, + std::vector output_shape, + std::vector weights, + uint32_t offset_kernel, + TOSA_REF_TYPE InDtype, + TOSA_REF_TYPE WeightDtype, + std::string& msg) { if (attribute->pad().size() 
!= (2 * conv_dimension)) { @@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute, return 1; } - if (InDtype != DType_INT8 && attribute->input_zp() != 0) { + if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0) + { msg = "Input zero point must be zero for non-int8 data"; return 1; } - if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) { + if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0) + { msg = "Weight zero point must be zero for non-int8 data"; return 1; } @@ -318,7 +320,7 @@ int check_fft_shape(const std::vector& in_real, return 0; } -template +template OpArgMax::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -330,14 +332,14 @@ OpArgMax::OpArgMax(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template +template OpArgMax::~OpArgMax() { if (attribute) delete attribute; } -template +template int OpArgMax::checkTensorAttributes() { if (validateRequiredOperands()) @@ -355,7 +357,7 @@ int OpArgMax::checkTensorAttributes() return 1; } - if (outputs[0]->getDtype() != DType_INT32) + if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32) { printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator"); return 1; @@ -400,7 +402,7 @@ int OpArgMax::checkTensorAttributes() return 0; } -template +template int OpArgMax::eval() { Eigen::Tensor index = this->input->getTensor().argmax(attribute->axis()); @@ -410,7 +412,7 @@ int OpArgMax::eval() return GraphNode::eval(); } -template +template OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -422,14 +424,14 @@ OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pool); } -template +template OpAvgPool2d::~OpAvgPool2d() { if (attribute) delete attribute; } -template +template int OpAvgPool2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -449,8 +451,10 @@ int OpAvgPool2d::checkTensorAttributes() in = dynamic_cast*>(inputs[0]); out = dynamic_cast*>(outputs[0]); - ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); - ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, + "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); std::string msg; if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) @@ -466,8 +470,9 @@ int OpAvgPool2d::checkTensorAttributes() // This calculates the number of padding elements used for each location along an axis // Average pooling only divides by the number of elements used, not including padding. 
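+// As a worked illustration (values chosen purely for exposition): with in_size = 4,
+// kernel_size = 3, stride = 1 and pad_left = pad_right = 1, the resulting map is
+// [2, 3, 3, 2], so the first and last output positions divide by 2 rather than 3
+// along this axis, because they overlap only two real input elements.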
// This function uses left/right, but is also used for vertical padding with top/bottom -template -ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) +template +ETensor1 OpAvgPool2d::calculate_div_map_1d( + int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) { ETensor1 result(out_size); @@ -495,7 +500,7 @@ ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size // assuming input and output tensor have same scales like tflite reference // so no need to scale input and output -template +template int OpAvgPool2d::eval() { int in_batch = this->in->getShape()[0]; @@ -531,7 +536,7 @@ int OpAvgPool2d::eval() LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL"); LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL"); - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype()); DEBUG_INFO(OP, "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " @@ -556,7 +561,7 @@ int OpAvgPool2d::eval() pad[3] = std::make_pair(0, 0); ETensor4 input_val = this->in->getTensor(); - if (Dtype == DType_INT8) + if (Dtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); } @@ -604,7 +609,8 @@ int OpAvgPool2d::eval() dm2_h.contract(dm2_w, contract_dims) .reshape(Eigen::array{ 1, out_height, out_width, 1 }) .broadcast(bcast); - if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16) + if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 && + Dtype != TOSA_REF_TYPE_FP64) { try { @@ -632,7 +638,7 @@ int OpAvgPool2d::eval() return GraphNode::eval(); } -template +template OpConv2d::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -644,14 +650,14 @@ OpConv2d::OpConv2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template +template OpConv2d::~OpConv2d() { if (attribute) delete attribute; } -template +template int OpConv2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -688,7 +694,7 @@ int OpConv2d::checkTensorAttributes() return 0; } -template +template int OpConv2d::eval() { int in_batch = this->input->getShape()[0]; @@ -781,7 +787,7 @@ int OpConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -817,7 +823,7 @@ int OpConv2d::eval() // reshape back to [N, H, W, C] this->output->getTensor() = biased_output.reshape(col2im_output_dims); - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -826,7 +832,7 @@ int OpConv2d::eval() return GraphNode::eval(); } -template +template OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -838,14 +844,14 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template +template OpConv3d::~OpConv3d() { if (attribute) delete attribute; } -template 
+template int OpConv3d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -882,7 +888,7 @@ int OpConv3d::checkTensorAttributes() return 0; } -template +template int OpConv3d::eval() { int in_batch = this->input->getShape()[0]; @@ -959,7 +965,7 @@ int OpConv3d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1020,7 +1026,7 @@ int OpConv3d::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1029,10 +1035,10 @@ int OpConv3d::eval() return GraphNode::eval(); } -template +template OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1041,14 +1047,14 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template +template OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template +template int OpDepthwiseConv2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d::checkTensorAttributes() return 0; } -template +template int OpDepthwiseConv2d::eval() { int in_batch = this->input->getShape()[0]; @@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d::eval() return GraphNode::eval(); } -template +template OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); @@ -1226,14 +1232,14 @@ OpFullyConnected::OpFullyConnected(SubgraphTrave INIT_ATTRIBUTE(FullyConnected); } -template +template OpFullyConnected::~OpFullyConnected() { if (attribute) delete attribute; } -template +template int OpFullyConnected::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1265,13 +1271,15 @@ int OpFullyConnected::checkTensorAttributes() output = dynamic_cast*>(outputs[0]); - ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); - ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); + ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpFullyConnected: Input zeropoint 
must be zero for non int8_t data"); + ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0, + "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); return 0; } -template +template int OpFullyConnected::eval() { typedef Eigen::Tensor::DimensionPair DimPair; @@ -1289,7 +1297,7 @@ int OpFullyConnected::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1299,7 +1307,7 @@ int OpFullyConnected::eval() input_val.template cast().contract(weight_val.template cast(), dims).template cast() + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1307,7 +1315,7 @@ int OpFullyConnected::eval() return GraphNode::eval(); } -template +template OpMatMul::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1319,14 +1327,14 @@ OpMatMul::OpMatMul(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(MatMul); } -template +template OpMatMul::~OpMatMul() { if (attribute) delete attribute; } -template +template int OpMatMul::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1382,13 +1390,15 @@ int OpMatMul::checkTensorAttributes() } W = b->getShape()[2]; - ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data"); - ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0, + "OpMatMul: A zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0, + "OpMatMul: B zeropoint must be zero for non int8_t data"); return 0; } -template +template int OpMatMul::eval() { typedef Eigen::Tensor::DimensionPair DimPair; @@ -1396,7 +1406,7 @@ int OpMatMul::eval() TIn a_val = this->a->getTensor(); TIn b_val = this->b->getTensor(); - if (Dtype == DType_INT8) + if (Dtype == TOSA_REF_TYPE_INT8) { a_val = a_val - (InEigenType)attribute->a_zp(); b_val = b_val - (InEigenType)attribute->b_zp(); @@ -1434,7 +1444,7 @@ int OpMatMul::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1443,7 +1453,7 @@ int OpMatMul::eval() return GraphNode::eval(); } -template +template OpMaxPool2d::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1455,14 +1465,14 @@ OpMaxPool2d::OpMaxPool2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pool); } -template +template OpMaxPool2d::~OpMaxPool2d() { if (attribute) delete attribute; } -template +template int OpMaxPool2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1493,7 +1503,7 @@ int OpMaxPool2d::checkTensorAttributes() return 0; } -template +template int OpMaxPool2d::eval() { int in_batch = this->in->getShape()[0]; @@ -1586,10 +1596,8 @@ int OpMaxPool2d::eval() return 
GraphNode::eval(); } -template -OpFFT2d::OpFFT2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template +OpFFT2d::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FFT2D, id_) { setRequiredOperands(2, 2); @@ -1598,14 +1606,14 @@ OpFFT2d::OpFFT2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(FFT); } -template -OpFFT2d::~OpFFT2d() { +template +OpFFT2d::~OpFFT2d() +{ if (attribute) delete attribute; } - -template +template int OpFFT2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1643,7 +1651,7 @@ int OpFFT2d::checkTensorAttributes() return 0; } -template +template int OpFFT2d::eval() { int in_real_batch = this->in_real->getShape()[0]; @@ -1709,7 +1717,7 @@ int OpFFT2d::eval() return GraphNode::eval(); } -template +template OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1719,11 +1727,11 @@ OpRFFT2d::OpRFFT2d(SubgraphTraverser* sgt_, setRequiredRank(3); } -template +template OpRFFT2d::~OpRFFT2d() {} -template +template int OpRFFT2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1759,7 +1767,7 @@ int OpRFFT2d::checkTensorAttributes() return 0; } -template +template int OpRFFT2d::eval() { int32_t in_batch = in->getShape()[0]; @@ -1815,10 +1823,10 @@ int OpRFFT2d::eval() return GraphNode::eval(); } -template +template OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1827,14 +1835,14 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template +template OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; } -template +template int OpTransposeConv2d::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1923,13 +1931,15 @@ int OpTransposeConv2d::checkTensorAttributes() return 1; } - ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); - ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); + ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0, + "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); return 0; } -template +template int OpTransposeConv2d::eval() { int in_batch = this->input->getShape()[0]; @@ -1985,7 +1995,7 @@ int OpTransposeConv2d::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -2040,7 +2050,7 @@ int OpTransposeConv2d::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); @@ -2062,6 +2073,7 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); // [in_t, weight_t, out_t] DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); @@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); @@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); @@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48); @@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); @@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); 
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 9ef4a58..df53f2b 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -24,7 +24,7 @@ using namespace tosa; namespace TosaReference { -template +template class OpArgMax : public GraphNode { public: @@ -35,7 +35,7 @@ public: virtual int eval(); using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate* output; }; -template +template class OpAvgPool2d : public GraphNode { public: @@ -74,7 +74,7 @@ protected: ETensor1 calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template +template class OpConv2d : public GraphNode { public: @@ -104,7 +104,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpConv3d : public GraphNode { public: @@ -134,7 +134,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpDepthwiseConv2d : public GraphNode { public: @@ -164,7 +164,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpFullyConnected : public GraphNode { public: @@ -195,7 +195,7 @@ protected: tosa::TosaFullyConnectedAttribute* attribute; }; -template +template class OpMatMul : public GraphNode { public: @@ -227,7 +227,7 @@ protected: tosa::TosaMatMulAttribute* attribute; }; -template +template class OpMaxPool2d : public GraphNode { public: @@ -248,7 +248,7 @@ protected: tosa::TosaPoolAttribute* attribute; }; -template +template class OpFFT2d : public GraphNode { public: @@ -271,7 +271,7 @@ protected: tosa::TosaFFTAttribute* attribute; }; -template +template class OpRFFT2d : public GraphNode { public: @@ -292,7 +292,7 @@ protected: TosaReference::TensorTemplate* out_imag; }; -template +template class OpTransposeConv2d : public GraphNode { public: diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 9034add..68ffb1f 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
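The substantive change in this file is the new CastHelper specialisation for casts out
of FP64: round to nearest with std::rint, then clamp to the destination range. A minimal
standalone sketch of that rule follows; it is an illustration only, the name
fp64_to_int32 and the clamp in the double domain are simplifications, not code from
the patch.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Round to nearest (ties to even under the default FP environment), then clamp
    // before converting, mirroring the rint + max/min pattern used by CastHelper.
    int32_t fp64_to_int32(double in)
    {
        double out = std::rint(in);
        out        = std::max(out, -2147483648.0); // INT32 min, exactly representable
        out        = std::min(out, 2147483647.0);  // INT32 max, exactly representable
        return static_cast<int32_t>(out);
    }

The clamp happens here in the double domain so the final conversion is always in range;
the patch itself clamps after assigning to the integer OutEigenType, using the GetQMin
and GetQMax values from template_types.h.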
@@ -24,7 +24,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template +template OpRescale::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpRescale::OpRescale(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Rescale); } -template +template OpRescale::~OpRescale() { if (attribute) delete attribute; } -template +template int OpRescale::checkTensorAttributes() { // Check Tosa Level @@ -69,31 +69,33 @@ int OpRescale::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0)) + if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) && + (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) + if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) && + (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) { - printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) { - printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if (attribute->scale32() && (InDtype == DType_INT48)) + if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48)) { printNodeValidationError("OpRescale: Scale set to true but input type is INT48"); return 1; @@ -108,7 +110,7 @@ int OpRescale::checkTensorAttributes() return 0; } -template +template int OpRescale::eval() { int32_t input_zp = attribute->input_zp(); @@ -237,7 +239,7 @@ int OpRescale::eval() return GraphNode::eval(); } -template +template OpCast::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -247,11 +249,11 @@ OpCast::OpCast(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template +template OpCast::~OpCast() {} -template +template int OpCast::checkTensorAttributes() { // Check Tosa Level @@ -281,7 +283,7 @@ int OpCast::checkTensorAttributes() return 0; } -template +template int OpCast::eval() { this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn()); @@ -289,7 +291,7 @@ int OpCast::eval() return GraphNode::eval(); } -template +template CastHelper::CastHelper() { fcn = [](InEigenType in) -> OutEigenType { @@ -298,14 +300,14 @@ CastHelper::CastHelper() }; } -template -CastHelper::CastHelper() 
+template +CastHelper::CastHelper() { fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { fcn = [](bool in) -> OutEigenType { OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0; @@ -313,8 +315,8 @@ CastHelper::CastHelper() }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // Integer data converted to fp16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -324,17 +326,17 @@ CastHelper::CastHelper() }; } -CastHelper::CastHelper() +CastHelper::CastHelper() { // fp32 data converted to fp16 (stored as fp32) fcn = [](float in) -> float { - float out = fpTrunc(in); // truncate required for conversion from higher precision + float out = fpTrunc(in); // truncate required for conversion from higher precision return out; }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // Integer data converted to bf16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -343,16 +345,16 @@ CastHelper::CastHelper() }; } -CastHelper::CastHelper() +CastHelper::CastHelper() { // fp32 data converted to bf16 (stored as fp32) fcn = [](float in) -> float { - return fpTrunc(in); // truncate required for conversions from higher precision + return fpTrunc(in); // truncate required for conversions from higher precision }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // fp16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -366,7 +368,7 @@ CastHelper::CastHelper() }; } -CastHelper::CastHelper() +CastHelper::CastHelper() { // No-op since fp16 values treated internally as their fp32 representation fcn = [](float in) -> OutEigenType { @@ -374,8 +376,8 @@ CastHelper::CastHelper() }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // bf16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -386,7 +388,7 @@ CastHelper::CastHelper() }; } -CastHelper::CastHelper() +CastHelper::CastHelper() { // No-op since bf16 values treated as truncated fp32 internally fcn = [](InEigenType in) -> OutEigenType { @@ -394,8 +396,8 @@ CastHelper::CastHelper() }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // Integer data converted to fp32 fcn = [](InEigenType in) -> float { @@ -404,8 +406,8 @@ CastHelper::CastHelper() }; } -template -CastHelper::CastHelper() +template +CastHelper::CastHelper() { // fp32 data converted to integer fcn = [](float in) -> OutEigenType { @@ -416,6 +418,31 @@ CastHelper::CastHelper() }; } +template +CastHelper::CastHelper() +{ + switch (OutDtype) + { + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: + // fp64 data converted to integer + fcn = [](InEigenType in) -> OutEigenType { + OutEigenType out = std::rint(in); + out = std::max(out, OutMin); + out = std::min(out, OutMax); + return out; + }; + break; + case TOSA_REF_TYPE_FP64: + // no op + fcn = [](InEigenType in) -> OutEigenType { return in; }; + break; + default: + ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype)); + } +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); @@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); 
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index e2fc6e2..98799a0 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ using namespace tosa; namespace TosaReference { -template +template class OpRescale : public GraphNode { public: @@ -46,7 +46,7 @@ protected: TosaReference::TensorTemplate* out; }; -template +template class CastHelper { public: @@ -64,12 +64,12 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -81,11 +81,11 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: - using InEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using FcnType = std::function; static constexpr int32_t OutMin = GetQMin::value; @@ -100,12 +100,12 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -117,11 +117,11 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: - using InEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using FcnType = std::function; static constexpr int32_t OutMin = GetQMin::value; @@ -137,11 +137,11 @@ private: }; template <> -class CastHelper +class CastHelper { public: - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -153,12 +153,12 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -170,11 +170,11 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { 
public: - using InEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using FcnType = std::function; static constexpr int32_t OutMin = GetQMin::value; @@ -190,11 +190,11 @@ private: }; template <> -class CastHelper +class CastHelper { public: - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -206,12 +206,12 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -224,11 +224,11 @@ private: }; template <> -class CastHelper +class CastHelper { public: - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -241,11 +241,11 @@ private: }; template <> -class CastHelper +class CastHelper { public: - using InEigenType = typename GetEigenType::type; - using OutEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; using FcnType = std::function; CastHelper(); const FcnType& get_fcn() const @@ -257,11 +257,11 @@ private: FcnType fcn; }; -template -class CastHelper +template +class CastHelper { public: - using InEigenType = typename GetEigenType::type; + using InEigenType = typename GetEigenType::type; using OutEigenType = typename GetEigenType::type; using FcnType = std::function; static constexpr int32_t OutMin = GetQMin::value; @@ -276,7 +276,26 @@ private: FcnType fcn; }; -template +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template class OpCast : public GraphNode { public: diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index e7641ba..4508291 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -138,9 +138,9 @@ int SubgraphTraverser::initializeGraph() for (auto op : block->GetOperators()) { // translated TosaSerializationOperator to GraphNode - DType input_dtype = DType_UNKNOWN; - DType output_dtype = DType_UNKNOWN; - DType weight_dtype = DType_UNKNOWN; + TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN; + TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN; + TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN; uint32_t input_rank = 0; uint32_t output_rank = 0; uint32_t weight_rank = 0; @@ -185,7 +185,7 @@ int SubgraphTraverser::initializeGraph() !input_tensor, "SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler", input_name.c_str()); - input_dtype = input_tensor->GetDtype(); + input_dtype = 
ConvertDType(input_tensor->GetDtype()); input_rank = input_tensor->GetShape().size(); } @@ -207,7 +207,7 @@ int SubgraphTraverser::initializeGraph() !weight_tensor, "SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler", weight_name.c_str()); - weight_dtype = weight_tensor->GetDtype(); + weight_dtype = ConvertDType(weight_tensor->GetDtype()); weight_rank = weight_tensor->GetShape().size(); } @@ -220,7 +220,7 @@ int SubgraphTraverser::initializeGraph() !output_tensor, "SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler", output_name.c_str()); - output_dtype = output_tensor->GetDtype(); + output_dtype = ConvertDType(output_tensor->GetDtype()); output_rank = output_tensor->GetShape().size(); DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx, @@ -246,16 +246,16 @@ int SubgraphTraverser::initializeGraph() fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) " "-> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, - EnumNamesDType()[output_dtype], output_rank); + EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank, + EnumNameTOSAREFTYPE(output_dtype), output_rank); } else { fprintf(g_func_debug.func_debug_file, "SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), " "weight=(%s rank %d) -> (%s rank %d)", - EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank, - EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank); + EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank, + EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank); } for (auto& ts : op->GetInputTensorNames()) @@ -309,7 +309,7 @@ int SubgraphTraverser::initializeGraph() TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size()); SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d", - ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size()); + ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size()); addTensor(tensor); } @@ -411,73 +411,89 @@ int SubgraphTraverser::allocateTensor() if (!ts->GetData().empty()) { DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str()); - switch (ts->GetDtype()) + auto serialization_dtype = ts->GetDtype(); + switch (serialization_dtype) { - case DType_INT4: - { + case DType_INT4: { std::vector i4_data; TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data); std::vector i32_data(i4_data.begin(), i4_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT8: - { + case DType_INT8: { std::vector i8_data; TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data); std::vector i32_data(i8_data.begin(), i8_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT16: - { + case DType_INT16: { std::vector i16_data; TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data); std::vector i32_data(i16_data.begin(), i16_data.end()); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT32: - { + case DType_INT32: { std::vector 
i32_data; TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data); tensor->setTensorValueInt32(i32_data.size(), i32_data.data()); } break; - case DType_INT48: - { + case DType_INT48: { std::vector i64_data; TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data); tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; - case DType_FP16: - { + case DType_FP16: { // Interpret f16 data as float std::vector f16_data; TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data); - tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(f16_data.begin(), f16_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); + } } break; - case DType_BF16: - { + case DType_BF16: { std::vector fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); // Ensure valid bfloat16 stored in each float for (auto f : fp32_data) ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } } break; - case DType_FP32: - { + case DType_FP32: { std::vector fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } } break; - case DType_BOOL: - { + case DType_BOOL: { std::vector bool_data; TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data); @@ -493,7 +509,7 @@ int SubgraphTraverser::allocateTensor() break; default: SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.", - EnumNamesDType()[ts->GetDtype()]); + EnumNameDType(ts->GetDtype())); } } } @@ -802,14 +818,14 @@ int SubgraphTraverser::validateGraph() if (g_func_config.tosa_profile == 0) { - DType dtype = currTensor->getDtype(); + TOSA_REF_TYPE dtype = currTensor->getDtype(); // Float-point disallowed - if (dtype == DType_FP32 || dtype == DType_FP16) + if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16) { WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point " "disabled, but %s tensor %s found\n", - EnumNamesDType()[dtype], currTensor->getName().c_str()); + EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str()); return 1; } } diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 7af2069..08d5b2a 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -22,11 +22,14 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::vector shape_) +TosaReference::Tensor::Tensor(const std::string tensorName_, + const DType 
serializationDtype_, + const std::vector shape_) + : tensorName(tensorName_) + , serializationDtype(serializationDtype_) + , shape(shape_) + , tensorDtype(ConvertDType(serializationDtype_)) { - tensorName = std::string(tensorName_); - tensorDtype = tensorDtype_; - shape = std::vector(shape_); producer = nullptr; isValid = false; consumers.clear(); @@ -75,7 +78,7 @@ int TosaReference::Tensor::addConsumer(GraphNode* node) int TosaReference::Tensor::dumpTensorParams(FILE* out) const { - fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNamesDType()[getDtype()], + fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNameTOSAREFTYPE(getDtype()), getIsValid(), getRank(), getShapeAsString().c_str()); return 0; @@ -83,7 +86,7 @@ int TosaReference::Tensor::dumpTensorParams(FILE* out) const int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const { - out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " isValid=" << getIsValid() + out << "Name: " << getName() << " DType=" << EnumNameTOSAREFTYPE(getDtype()) << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n"; return 0; @@ -92,28 +95,33 @@ int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const int TosaReference::Tensor::readFromNpyFile(const char* filename) { uint32_t elements = getElementCount(); - float* fdatabuf = nullptr; + double* f64databuf = nullptr; + float* f32databuf = nullptr; half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; - DType dtype = getDtype(); + TOSA_REF_TYPE dtype = getDtype(); + DType serialization_dtype = getSerializationDtype(); - switch (dtype) + assert(dtype == ConvertDType(serialization_dtype)); + // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 + assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 || + serialization_dtype == DType_BF16); + + switch (serialization_dtype) { case DType_FP32: case DType_BF16: - fdatabuf = (float*)calloc(sizeof(float), elements); - ASSERT_MEM(fdatabuf); + f32databuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(f32databuf); - nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf); + nperror = NumpyUtilities::readFromNpyFile(filename, elements, f32databuf); break; case DType_FP16: f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); ASSERT_MEM(f16databuf); - fdatabuf = (float*)calloc(sizeof(float), elements); - ASSERT_MEM(fdatabuf); nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf); break; @@ -141,7 +149,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) nperror = NumpyUtilities::readFromNpyFile(filename, elements, bdatabuf); break; default: - FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + FATAL_ERROR("unknown tensor type=%s", EnumNameDType(serialization_dtype)); } switch (nperror) @@ -154,7 +162,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename); case NumpyUtilities::FILE_TYPE_MISMATCH: FATAL_ERROR("readFromNpyFile: Tensor type %s and Numpy file type mismatch for tensor %s filename %s", - EnumNamesDType()[getDtype()], getName().c_str(), filename); + EnumNameTOSAREFTYPE(getDtype()), getName().c_str(), filename); case 
NumpyUtilities::HEADER_PARSE_ERROR: FATAL_ERROR("Numpy header parsing error for file: %s", filename); case NumpyUtilities::BUFFER_SIZE_MISMATCH: @@ -166,75 +174,133 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) switch (dtype) { - case DType_FP16: + case TOSA_REF_TYPE_FP16: // Convert from fp16 to fp32 so that fp16 values can be manipulated as float + f32databuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(f32databuf); for (uint32_t i=0; i < elements; i++) { - fdatabuf[i] = half_float::half_cast(f16databuf[i]); + f32databuf[i] = half_float::half_cast(f16databuf[i]); } - if (setTensorValueFloat(elements, fdatabuf)) + if (setTensorValueFloat(elements, f32databuf)) { free(f16databuf); - free(fdatabuf); + free(f32databuf); return 1; } break; - case DType_BF16: + case TOSA_REF_TYPE_BF16: for (uint32_t i=0; i < elements; i++) { ASSERT_MSG( - checkValidBFloat(fdatabuf[i]), + checkValidBFloat(f32databuf[i]), "Input float value not a valid bfloat16 value." ); } - if (setTensorValueFloat(elements, fdatabuf)) + if (setTensorValueFloat(elements, f32databuf)) { - free(fdatabuf); + free(f32databuf); return 1; } break; - case DType_FP32: - if (setTensorValueFloat(elements, fdatabuf)) + case TOSA_REF_TYPE_FP32: + if (setTensorValueFloat(elements, f32databuf)) { - free(fdatabuf); + free(f32databuf); return 1; } break; - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: if (setTensorValueInt32(elements, i32databuf)) { free(i32databuf); return 1; } break; - case DType_INT48: + case TOSA_REF_TYPE_INT48: if (setTensorValueInt64(elements, i64databuf)) { free(i64databuf); return 1; } break; - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: if (setTensorValueBool(elements, bdatabuf)) { free(i32databuf); return 1; } break; + case TOSA_REF_TYPE_FP64: + switch (serialization_dtype) + { + case DType_FP16: + // FP16 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + f64databuf[i] = half_float::half_cast(f16databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f16databuf); + free(f64databuf); + return 1; + } + break; + case DType_BF16: + // BF16 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + ASSERT_MSG(checkValidBFloat(f32databuf[i]), "Input float value not a valid bfloat16 value."); + f64databuf[i] = static_cast(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; + case DType_FP32: + // FP32 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + f64databuf[i] = static_cast(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; + default: + FATAL_ERROR("unexpected tensor type=%s and original tensor type=%s", EnumNameTOSAREFTYPE(dtype), + EnumNameDType(serialization_dtype)); + } + break; default: - FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + FATAL_ERROR("unsupported tensor type=%s", EnumNameTOSAREFTYPE(dtype)); } setIsValid(); - if (fdatabuf) - free(fdatabuf); + if (f32databuf) + 
free(f32databuf); if (f16databuf) free(f16databuf); + if (f64databuf) + free(f64databuf); if (i32databuf) free(i32databuf); if (i64databuf) @@ -247,58 +313,59 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int TosaReference::Tensor::writeToNpyFile(const char* filename) const { - float* fdatabuf = nullptr; + float* f32databuf = nullptr; + double* f64databuf = nullptr; half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; uint32_t elements = getElementCount(); - DType dtype = getDtype(); + const TOSA_REF_TYPE dtype = getDtype(); switch (dtype) { - case DType_FP32: - case DType_BF16: - fdatabuf = (float*)calloc(sizeof(float), elements); - ASSERT_MEM(fdatabuf); + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_BF16: + f32databuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(f32databuf); - if (getTensorValueFloat(elements, fdatabuf)) + if (getTensorValueFloat(elements, f32databuf)) { - free(fdatabuf); + free(f32databuf); return 1; } - nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf); + nperror = NumpyUtilities::writeToNpyFile(filename, shape, f32databuf); - free(fdatabuf); + free(f32databuf); break; - case DType_FP16: - fdatabuf = (float*)calloc(sizeof(float), elements); - ASSERT_MEM(fdatabuf); + case TOSA_REF_TYPE_FP16: + f32databuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); ASSERT_MEM(f16databuf); - if (getTensorValueFloat(elements, fdatabuf)) + if (getTensorValueFloat(elements, f32databuf)) { - free(fdatabuf); + free(f32databuf); free(f16databuf); return 1; } // Convert fp32 to fp16 so that output file contains valid fp16 data for (uint32_t i=0; i < elements; i++) { - f16databuf[i] = half_float::half_cast(fdatabuf[i]); + f16databuf[i] = half_float::half_cast(f32databuf[i]); } nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf); - free(fdatabuf); + free(f32databuf); free(f16databuf); break; - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); ASSERT_MEM(i32databuf); @@ -312,7 +379,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(i32databuf); break; - case DType_INT48: + case TOSA_REF_TYPE_INT48: i64databuf = (int64_t*)calloc(sizeof(int64_t), elements); ASSERT_MEM(i64databuf); @@ -326,7 +393,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(i64databuf); break; - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: bdatabuf = (bool*)calloc(sizeof(bool), elements); ASSERT_MEM(bdatabuf); @@ -340,8 +407,22 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(bdatabuf); break; - default: - FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + case TOSA_REF_TYPE_FP64: + // @todo : support FP64 dtype + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + + if (getTensorValueDouble(elements, f64databuf)) + { + free(f64databuf); + return 1; + } + nperror = NumpyUtilities::writeToNpyFile(filename, shape, f64databuf); + + free(f64databuf); + break; + case TOSA_REF_TYPE_UNKNOWN: + FATAL_ERROR("unsupported tensor type=%s", 
EnumNameTOSAREFTYPE(getDtype())); } switch (nperror) @@ -386,11 +467,11 @@ int TosaReference::TensorTemplate::copyValueFrom(TosaReference::Tensor* src) return 1; \ } \ \ - uint32_t src_rank = src->getRank(); \ - uint32_t dst_rank = this->getRank(); \ - DType src_dtype = src->getDtype(); \ - DType dst_dtype = this->getDtype(); \ - bool tensor_match = true; \ + const uint32_t src_rank = src->getRank(); \ + const uint32_t dst_rank = this->getRank(); \ + const TOSA_REF_TYPE src_dtype = src->getDtype(); \ + const TOSA_REF_TYPE dst_dtype = this->getDtype(); \ + bool tensor_match = true; \ \ if ((src_rank != dst_rank) || (src_dtype != dst_dtype)) \ { \ @@ -413,8 +494,9 @@ int TosaReference::TensorTemplate::copyValueFrom(TosaReference::Tensor* src) { \ WARNING("source tensor %s (rank=%u, dtype=%s, shape=%s) doesn't match destination tensor %s (rank=%u, " \ "dtype=%s, shape=%s)", \ - src->getName().c_str(), src_rank, EnumNamesDType()[src_dtype], src->getShapeAsString().c_str(), \ - this->getName().c_str(), dst_rank, EnumNamesDType()[dst_dtype], this->getShapeAsString().c_str()); \ + src->getName().c_str(), src_rank, EnumNameTOSAREFTYPE(src_dtype), src->getShapeAsString().c_str(), \ + this->getName().c_str(), dst_rank, EnumNameTOSAREFTYPE(dst_dtype), \ + this->getShapeAsString().c_str()); \ return 1; \ } \ \ @@ -429,6 +511,13 @@ DEF_CTENSOR_COPY_VALUE_FROM(3, float) DEF_CTENSOR_COPY_VALUE_FROM(4, float) DEF_CTENSOR_COPY_VALUE_FROM(5, float) DEF_CTENSOR_COPY_VALUE_FROM(6, float) +DEF_CTENSOR_COPY_VALUE_FROM(0, double) +DEF_CTENSOR_COPY_VALUE_FROM(1, double) +DEF_CTENSOR_COPY_VALUE_FROM(2, double) +DEF_CTENSOR_COPY_VALUE_FROM(3, double) +DEF_CTENSOR_COPY_VALUE_FROM(4, double) +DEF_CTENSOR_COPY_VALUE_FROM(5, double) +DEF_CTENSOR_COPY_VALUE_FROM(6, double) DEF_CTENSOR_COPY_VALUE_FROM(0, int32_t) DEF_CTENSOR_COPY_VALUE_FROM(1, int32_t) DEF_CTENSOR_COPY_VALUE_FROM(2, int32_t) @@ -453,13 +542,37 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM +int TosaReference::Tensor::readfromVector(const ArrayProxy vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case TOSA_REF_TYPE_FP64: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueDouble(elements, vals.data()); + break; + default: + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameTOSAREFTYPE(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + int TosaReference::Tensor::readfromVector(const ArrayProxy vals) { uint32_t elements = getElementCount(); switch (getDtype()) { - case DType_FP16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP32: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -469,7 +582,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) setTensorValueFloat(elements, vals.data()); break; - case DType_BF16: + case TOSA_REF_TYPE_BF16: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -489,7 +602,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) break; default: WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } setIsValid(); @@ -503,7 +616,7 @@ int 
TosaReference::Tensor::readfromVector(const ArrayProxy val switch (getDtype()) { - case DType_FP16: + case TOSA_REF_TYPE_FP16: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -521,7 +634,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy val break; default: WARNING("The input type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } setIsValid(); @@ -533,12 +646,12 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) uint32_t elements = getElementCount(); switch (getDtype()) { - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -550,7 +663,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) break; default: WARNING("The input type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } setIsValid(); @@ -562,7 +675,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) uint32_t elements = getElementCount(); switch (getDtype()) { - case DType_INT48: + case TOSA_REF_TYPE_INT48: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -574,7 +687,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) break; default: WARNING("The input type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } setIsValid(); @@ -587,7 +700,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) switch (getDtype()) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -599,21 +712,45 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) break; default: WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } setIsValid(); return 0; } +int TosaReference::Tensor::writeToVector(ArrayProxy vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case TOSA_REF_TYPE_FP64: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueDouble(elements, vals.data()); + break; + default: + WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameTOSAREFTYPE(getDtype())); + return -2; + } + return 0; +} + int TosaReference::Tensor::writeToVector(ArrayProxy vals) { uint32_t elements = getElementCount(); switch (getDtype()) { - case DType_FP16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP32: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -623,7 +760,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) 
getTensorValueFloat(elements, vals.data()); break; - case DType_BF16: + case TOSA_REF_TYPE_BF16: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -644,7 +781,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) break; default: WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } return 0; @@ -657,7 +794,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) switch (getDtype()) { - case DType_FP16: + case TOSA_REF_TYPE_FP16: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -675,7 +812,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) break; default: WARNING("The output type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } return 0; @@ -687,12 +824,12 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) switch (getDtype()) { - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -704,7 +841,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) break; default: WARNING("The output type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } return 0; @@ -716,7 +853,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) switch (getDtype()) { - case tosa::DType_INT48: + case TOSA_REF_TYPE_INT48: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -728,7 +865,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) break; default: WARNING("The output type doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } return 0; @@ -740,7 +877,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) switch (getDtype()) { - case tosa::DType_BOOL: + case TOSA_REF_TYPE_BOOL: if (vals.size() != elements) { WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", @@ -752,14 +889,14 @@ int TosaReference::Tensor::writeToVector(ArrayProxy vals) break; default: WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).", - EnumNameDType(getDtype())); + EnumNameTOSAREFTYPE(getDtype())); return -2; } return 0; } template -int TosaReference::TensorTemplate::setTensorValueFloat(const size_t buflen, const float* vals) +int TosaReference::TensorTemplate::setTensorValueDouble(const size_t buflen, const double* vals) { FATAL_ERROR("TensorTemplate::setTensorValueFloat should not be called. 
" "Implement template specialization version."); @@ -767,7 +904,7 @@ int TosaReference::TensorTemplate::setTensorValueFloat(const size_t buflen, c } template <> -int TosaReference::Tensor0::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor0::setTensorValueDouble(const size_t bufLen, const double* vals) { ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); @@ -777,7 +914,7 @@ int TosaReference::Tensor0::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor1::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor1::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -792,7 +929,7 @@ int TosaReference::Tensor1::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor2::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor2::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -810,7 +947,7 @@ int TosaReference::Tensor2::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor3::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor3::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -831,7 +968,7 @@ int TosaReference::Tensor3::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor4::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor4::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -855,7 +992,7 @@ int TosaReference::Tensor4::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor5::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor5::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -882,7 +1019,7 @@ int TosaReference::Tensor5::setTensorValueFloat(const size_t bufLen, cons } template <> -int TosaReference::Tensor6::setTensorValueFloat(const size_t bufLen, const float* vals) +int TosaReference::Tensor6::setTensorValueDouble(const size_t bufLen, const double* vals) { uint32_t idx = 0; @@ -911,15 +1048,15 @@ int TosaReference::Tensor6::setTensorValueFloat(const size_t bufLen, cons } template -int TosaReference::TensorTemplate::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::TensorTemplate::setTensorValueFloat(const size_t buflen, const float* vals) { - FATAL_ERROR("TensorTemplate::setTensorValueInt32 should not be called. " + FATAL_ERROR("TensorTemplate::setTensorValueFloat should not be called. 
" "Implement template specialization version."); return 0; } template <> -int TosaReference::Tensor0::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor0::setTensorValueFloat(const size_t bufLen, const float* vals) { ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); @@ -929,7 +1066,7 @@ int TosaReference::Tensor0::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor1::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor1::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -944,7 +1081,7 @@ int TosaReference::Tensor1::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor2::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor2::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -962,7 +1099,7 @@ int TosaReference::Tensor2::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor3::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor3::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -983,7 +1120,7 @@ int TosaReference::Tensor3::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor4::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor4::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -1007,7 +1144,7 @@ int TosaReference::Tensor4::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor5::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor5::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -1034,7 +1171,7 @@ int TosaReference::Tensor5::setTensorValueInt32(const size_t bufLen, co } template <> -int TosaReference::Tensor6::setTensorValueInt32(const size_t bufLen, const int32_t* vals) +int TosaReference::Tensor6::setTensorValueFloat(const size_t bufLen, const float* vals) { uint32_t idx = 0; @@ -1063,15 +1200,15 @@ int TosaReference::Tensor6::setTensorValueInt32(const size_t bufLen, co } template -int TosaReference::TensorTemplate::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::TensorTemplate::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { - FATAL_ERROR("TensorTemplate::setTensorValueInt64 should not be called. " + FATAL_ERROR("TensorTemplate::setTensorValueInt32 should not be called. 
" "Implement template specialization version."); return 0; } template <> -int TosaReference::Tensor0::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor0::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); @@ -1081,7 +1218,7 @@ int TosaReference::Tensor0::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor1::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor1::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1096,7 +1233,7 @@ int TosaReference::Tensor1::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor2::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor2::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1114,7 +1251,7 @@ int TosaReference::Tensor2::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor3::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor3::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1135,7 +1272,7 @@ int TosaReference::Tensor3::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor4::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor4::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1159,7 +1296,7 @@ int TosaReference::Tensor4::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor5::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor5::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1186,7 +1323,7 @@ int TosaReference::Tensor5::setTensorValueInt64(const size_t bufLen, co } template <> -int TosaReference::Tensor6::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +int TosaReference::Tensor6::setTensorValueInt32(const size_t bufLen, const int32_t* vals) { uint32_t idx = 0; @@ -1215,15 +1352,15 @@ int TosaReference::Tensor6::setTensorValueInt64(const size_t bufLen, co } template -int TosaReference::TensorTemplate::setTensorValueBool(const size_t buflen, const bool* vals) +int TosaReference::TensorTemplate::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { - FATAL_ERROR("TensorTemplate::setTensorValueBool should not be called. " + FATAL_ERROR("TensorTemplate::setTensorValueInt64 should not be called. 
" "Implement template specialization version."); return 0; } template <> -int TosaReference::Tensor0::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor0::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); @@ -1233,7 +1370,7 @@ int TosaReference::Tensor0::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor1::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor1::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1248,7 +1385,7 @@ int TosaReference::Tensor1::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor2::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor2::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1266,7 +1403,7 @@ int TosaReference::Tensor2::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor3::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor3::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1287,7 +1424,7 @@ int TosaReference::Tensor3::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor4::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor4::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1311,7 +1448,7 @@ int TosaReference::Tensor4::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor5::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor5::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1338,7 +1475,7 @@ int TosaReference::Tensor5::setTensorValueBool(const size_t bufLen, const } template <> -int TosaReference::Tensor6::setTensorValueBool(const size_t bufLen, const bool* vals) +int TosaReference::Tensor6::setTensorValueInt64(const size_t bufLen, const int64_t* vals) { uint32_t idx = 0; @@ -1367,64 +1504,50 @@ int TosaReference::Tensor6::setTensorValueBool(const size_t bufLen, const } template -int TosaReference::TensorTemplate::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::TensorTemplate::setTensorValueBool(const size_t buflen, const bool* vals) { - FATAL_ERROR("TensorTemplate::getTensorValueFloat should not be called. " + FATAL_ERROR("TensorTemplate::setTensorValueBool should not be called. 
" "Implement template specialization version."); return 0; } template <> -int TosaReference::Tensor0::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor0::setTensorValueBool(const size_t bufLen, const bool* vals) { - int totalVals = 1; - - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); - vals[0] = (*tensor)(0); + (*tensor)(0) = vals[0]; return 0; } template <> -int TosaReference::Tensor1::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor1::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } + uint32_t idx = 0; - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); for (int i0 = 0; i0 < shape[0]; i0++) { - vals[idx++] = (*tensor)(i0); + (*tensor)(i0) = vals[idx++]; } return 0; } template <> -int TosaReference::Tensor2::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor2::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } + uint32_t idx = 0; - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); for (int i0 = 0; i0 < shape[0]; i0++) { for (int i1 = 0; i1 < shape[1]; i1++) { - vals[idx++] = (*tensor)(i0, i1); + (*tensor)(i0, i1) = vals[idx++]; } } @@ -1432,17 +1555,11 @@ int TosaReference::Tensor2::getTensorValueFloat(const size_t bufLen, floa } template <> -int TosaReference::Tensor3::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor3::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } + uint32_t idx = 0; - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); for (int i0 = 0; i0 < shape[0]; i0++) { @@ -1450,7 +1567,7 @@ int TosaReference::Tensor3::getTensorValueFloat(const size_t bufLen, floa { for (int i2 = 0; i2 < shape[2]; i2++) { - vals[idx++] = (*tensor)(i0, i1, i2); + (*tensor)(i0, i1, i2) = vals[idx++]; } } } @@ -1459,17 +1576,11 @@ int TosaReference::Tensor3::getTensorValueFloat(const size_t bufLen, floa } template <> -int TosaReference::Tensor4::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor4::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } + uint32_t idx = 0; - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); for (int i0 = 0; i0 < shape[0]; i0++) { @@ -1479,7 +1590,7 @@ int TosaReference::Tensor4::getTensorValueFloat(const size_t bufLen, floa { for (int i3 = 0; i3 < shape[3]; i3++) { - vals[idx++] = (*tensor)(i0, i1, i2, i3); + (*tensor)(i0, i1, i2, i3) = vals[idx++]; } } } @@ -1489,17 +1600,11 @@ int TosaReference::Tensor4::getTensorValueFloat(const size_t 
bufLen, floa } template <> -int TosaReference::Tensor5::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor5::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } + uint32_t idx = 0; - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); for (int i0 = 0; i0 < shape[0]; i0++) { @@ -1511,7 +1616,7 @@ int TosaReference::Tensor5::getTensorValueFloat(const size_t bufLen, floa { for (int i4 = 0; i4 < shape[4]; i4++) { - vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + (*tensor)(i0, i1, i2, i3, i4) = vals[idx++]; } } } @@ -1522,10 +1627,384 @@ int TosaReference::Tensor5::getTensorValueFloat(const size_t bufLen, floa } template <> -int TosaReference::Tensor6::getTensorValueFloat(const size_t bufLen, float* vals) const +int TosaReference::Tensor6::setTensorValueBool(const size_t bufLen, const bool* vals) { - uint32_t idx = 0; - int totalVals = 1; + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++]; + } + } + } + } + } + } + return 0; +} + +template +int TosaReference::TensorTemplate::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + FATAL_ERROR("TensorTemplate::getTensorValueDouble should not be called. " + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int 
TosaReference::Tensor4::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6::getTensorValueDouble(const size_t bufLen, double* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template +int TosaReference::TensorTemplate::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + FATAL_ERROR("TensorTemplate::getTensorValueFloat should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; for (size_t i = 0; i < shape.size(); i++) { @@ -2092,38 +2571,114 @@ int TosaReference::Tensor5::getTensorValueBool(const size_t bufLen, bool* } template <> -int TosaReference::Tensor6::getTensorValueBool(const size_t bufLen, bool* vals) const +int TosaReference::Tensor6::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + 
ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template <> +int TosaReference::Tensor0::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor0(); + + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor1::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor1(shape[0]); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor2::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor2(shape[0], shape[1]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor3::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor3(shape[0], shape[1], shape[2]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor4::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor4(shape[0], shape[1], shape[2], shape[3]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor5::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor5(shape[0], shape[1], shape[2], shape[3], shape[4]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor6::allocate() { - uint32_t idx = 0; - int totalVals = 1; - - for (size_t i = 0; i < shape.size(); i++) - { - totalVals *= shape[i]; - } - - ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); - - for (int i0 = 0; i0 < shape[0]; i0++) - { - for (int i1 = 0; i1 < shape[1]; i1++) - { - for (int i2 = 0; i2 < shape[2]; i2++) - { - for (int i3 = 0; i3 < shape[3]; i3++) - { - for (int i4 = 0; i4 < shape[4]; i4++) - { - for (int i5 = 0; i5 < shape[5]; i5++) - { - vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); - } - } - } - } - } - } - return 0; + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor6(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]); + if (tensor) + return 0; + else + return 1; } template <> @@ -2427,6 +2982,230 @@ int TosaReference::Tensor6::allocate() return 1; } +template <> +int TosaReference::Tensor0::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), "[ %%%sf ]\n", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, fp_fmt, (*tensor)(0)); + + return 0; +} + +template <> +int TosaReference::Tensor1::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, fp_fmt, (*tensor)(i0)); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor2::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf 
", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor3::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1, i2)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor4::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor5::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor6::dumpTensor(FILE* out) const +{ + char fp_fmt[32]; + snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str()); + + if (tensor == nullptr) + { + fprintf(out, "\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, "["); + for (int i5 = 0; i5 < shape[5]; i5++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4, i5)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + template <> int TosaReference::Tensor0::dumpTensor(FILE* out) const { @@ -3342,6 +4121,14 @@ int TosaReference::TensorTemplate::dumpTensor(FILE* out) const } // template explicit specialization +template class TosaReference::TensorTemplate>; +template class 
TosaReference::TensorTemplate>; +template class TosaReference::TensorTemplate>; +template class TosaReference::TensorTemplate>; +template class TosaReference::TensorTemplate>; +template class TosaReference::TensorTemplate>; +template class TosaReference::TensorTemplate>; + template class TosaReference::TensorTemplate>; template class TosaReference::TensorTemplate>; template class TosaReference::TensorTemplate>; diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index d5f1de8..08ee8bf 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -17,9 +17,9 @@ #define TOSA_REFERENCE_TENSOR_H #include "array_proxy.h" +#include "dtype.h" #include "model_common.h" #include "ops/template_types.h" -#include "tosa_generated.h" #include "tosa_serialization_handler.h" #include #include @@ -34,7 +34,7 @@ class GraphNode; class Tensor { public: - Tensor(std::string tensorName_, DType tensorDtype__, std::vector shape_); + Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector shape_); virtual ~Tensor(); @@ -212,19 +212,26 @@ public: return shape.size(); } - const DType getDtype() const + const TOSA_REF_TYPE getDtype() const { return tensorDtype; } + const DType getSerializationDtype() const + { + return serializationDtype; + } + virtual int dumpTensor(FILE* out) const = 0; virtual int dumpTensorParams(FILE* out) const; virtual int dumpTensorParams(std::ostream& out) const; + virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0; virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0; virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0; virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0; virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0; + virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0; @@ -234,12 +241,14 @@ public: virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; + virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); virtual int readfromVector(const ArrayProxy vals); + virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); virtual int writeToVector(ArrayProxy vals); @@ -258,10 +267,11 @@ public: virtual bool is_allocated() = 0; protected: - std::string tensorName; - DType tensorDtype; + const std::string tensorName; + const DType serializationDtype; + const std::vector shape; + const TOSA_REF_TYPE tensorDtype; int isValid; - std::vector shape; int isSubgraphInput; int isSubgraphOutput; bool isAllocated; @@ -284,8 +294,8 @@ template class TensorTemplate : public Tensor { public: - TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector shape_) - : Tensor(tensorName_, tensorDtype_, shape_) + TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector shape_) + : Tensor(tensorName_, dtype_, shape_) { tensor = nullptr; } @@ -330,10 +340,13 @@ public: virtual int dumpTensor(FILE* out) 
const; + virtual int setTensorValueDouble(const size_t bufLen, const double* vals); virtual int setTensorValueFloat(const size_t bufLen, const float* vals); virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals); virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals); virtual int setTensorValueBool(const size_t bufLen, const bool* vals); + + virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const; virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const; virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const; virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const; @@ -362,6 +375,21 @@ int Tensor5::allocate(); template <> int Tensor6::allocate(); +template <> +int Tensor0::allocate(); +template <> +int Tensor1::allocate(); +template <> +int Tensor2::allocate(); +template <> +int Tensor3::allocate(); +template <> +int Tensor4::allocate(); +template <> +int Tensor5::allocate(); +template <> +int Tensor6::allocate(); + template <> int Tensor0::allocate(); template <> @@ -422,6 +450,21 @@ int Tensor5::copyValueFrom(Tensor* src); template <> int Tensor6::copyValueFrom(Tensor* src); +template <> +int Tensor0::copyValueFrom(Tensor* src); +template <> +int Tensor1::copyValueFrom(Tensor* src); +template <> +int Tensor2::copyValueFrom(Tensor* src); +template <> +int Tensor3::copyValueFrom(Tensor* src); +template <> +int Tensor4::copyValueFrom(Tensor* src); +template <> +int Tensor5::copyValueFrom(Tensor* src); +template <> +int Tensor6::copyValueFrom(Tensor* src); + template <> int Tensor0::copyValueFrom(Tensor* src); template <> @@ -557,6 +600,36 @@ int Tensor5::getTensorValueFloat(const size_t bufLen, float* vals) const; template <> int Tensor6::getTensorValueFloat(const size_t bufLen, float* vals) const; +template <> +int Tensor0::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor1::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor2::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor3::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor4::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor5::setTensorValueDouble(const size_t bufLen, const double* vals); +template <> +int Tensor6::setTensorValueDouble(const size_t bufLen, const double* vals); + +template <> +int Tensor0::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor1::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor2::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor3::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor4::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor5::getTensorValueDouble(const size_t bufLen, double* vals) const; +template <> +int Tensor6::getTensorValueDouble(const size_t bufLen, double* vals) const; + template <> int Tensor0::setTensorValueBool(const size_t bufLen, const bool* vals); template <> @@ -587,7 +660,6 @@ int Tensor5::getTensorValueBool(const size_t bufLen, bool* vals) const; template <> int Tensor6::getTensorValueBool(const size_t bufLen, bool* vals) const; -// assume we only dump float type tensor now template <> int Tensor0::dumpTensor(FILE* out) const; template <> @@ -603,6 +675,20 @@ int Tensor5::dumpTensor(FILE* out) const; template 
<> int Tensor6::dumpTensor(FILE* out) const; template <> +int Tensor0::dumpTensor(FILE* out) const; +template <> +int Tensor1::dumpTensor(FILE* out) const; +template <> +int Tensor2::dumpTensor(FILE* out) const; +template <> +int Tensor3::dumpTensor(FILE* out) const; +template <> +int Tensor4::dumpTensor(FILE* out) const; +template <> +int Tensor5::dumpTensor(FILE* out) const; +template <> +int Tensor6::dumpTensor(FILE* out) const; +template <> int Tensor0::dumpTensor(FILE* out) const; template <> int Tensor1::dumpTensor(FILE* out) const; @@ -648,100 +734,119 @@ int Tensor6::dumpTensor(FILE* out) const; class TensorFactory { public: - static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector shape_, const uint32_t rank) + static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector shape_, const uint32_t rank) { + TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_); switch (tensorDtype_) { - case DType_FP32: - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + switch (rank) + { + case 0: + return new Tensor0(tensorName_, dtype_, shape_); + case 1: + return new Tensor1(tensorName_, dtype_, shape_); + case 2: + return new Tensor2(tensorName_, dtype_, shape_); + case 3: + return new Tensor3(tensorName_, dtype_, shape_); + case 4: + return new Tensor4(tensorName_, dtype_, shape_); + case 5: + return new Tensor5(tensorName_, dtype_, shape_); + case 6: + return new Tensor6(tensorName_, dtype_, shape_); + } + break; + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_UINT8: + case TOSA_REF_TYPE_INT4: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_UINT16: switch (rank) { case 0: - return new Tensor0(tensorName_, tensorDtype_, shape_); + return new Tensor0(tensorName_, dtype_, shape_); case 1: - return new Tensor1(tensorName_, tensorDtype_, shape_); + return new Tensor1(tensorName_, dtype_, shape_); case 2: - return new Tensor2(tensorName_, tensorDtype_, shape_); + return new Tensor2(tensorName_, dtype_, shape_); case 3: - return new Tensor3(tensorName_, tensorDtype_, shape_); + return new Tensor3(tensorName_, dtype_, shape_); case 4: - return new Tensor4(tensorName_, tensorDtype_, shape_); + return new Tensor4(tensorName_, dtype_, shape_); case 5: - return new Tensor5(tensorName_, tensorDtype_, shape_); + return new Tensor5(tensorName_, dtype_, shape_); case 6: - return new Tensor6(tensorName_, tensorDtype_, shape_); + return new Tensor6(tensorName_, dtype_, shape_); } break; - case DType_INT32: - case DType_UINT8: - case DType_INT4: - case DType_INT8: - case DType_INT16: - case DType_UINT16: + case TOSA_REF_TYPE_INT48: switch (rank) { case 0: - return new Tensor0(tensorName_, tensorDtype_, shape_); + return new Tensor0(tensorName_, dtype_, shape_); case 1: - return new Tensor1(tensorName_, tensorDtype_, shape_); + return new Tensor1(tensorName_, dtype_, shape_); case 2: - return new Tensor2(tensorName_, tensorDtype_, shape_); + return new Tensor2(tensorName_, dtype_, shape_); case 3: - return new Tensor3(tensorName_, tensorDtype_, shape_); + return new Tensor3(tensorName_, dtype_, shape_); case 4: - return new Tensor4(tensorName_, tensorDtype_, shape_); + return new Tensor4(tensorName_, dtype_, shape_); case 5: - return new Tensor5(tensorName_, tensorDtype_, shape_); + return new Tensor5(tensorName_, dtype_, shape_); case 6: - return new Tensor6(tensorName_, tensorDtype_, shape_); + return new Tensor6(tensorName_, dtype_, shape_); } break; - case DType_INT48: + case 
TOSA_REF_TYPE_BOOL:
                switch (rank)
                {
                    case 0:
-                        return new Tensor0(tensorName_, tensorDtype_, shape_);
+                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
-                        return new Tensor1(tensorName_, tensorDtype_, shape_);
+                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
-                        return new Tensor2(tensorName_, tensorDtype_, shape_);
+                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
-                        return new Tensor3(tensorName_, tensorDtype_, shape_);
+                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
-                        return new Tensor4(tensorName_, tensorDtype_, shape_);
+                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
-                        return new Tensor5(tensorName_, tensorDtype_, shape_);
+                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
-                        return new Tensor6(tensorName_, tensorDtype_, shape_);
+                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
-            case DType_BOOL:
+            case TOSA_REF_TYPE_FP64:
                switch (rank)
                {
                    case 0:
-                        return new Tensor0(tensorName_, tensorDtype_, shape_);
+                        return new Tensor0(tensorName_, dtype_, shape_);
                    case 1:
-                        return new Tensor1(tensorName_, tensorDtype_, shape_);
+                        return new Tensor1(tensorName_, dtype_, shape_);
                    case 2:
-                        return new Tensor2(tensorName_, tensorDtype_, shape_);
+                        return new Tensor2(tensorName_, dtype_, shape_);
                    case 3:
-                        return new Tensor3(tensorName_, tensorDtype_, shape_);
+                        return new Tensor3(tensorName_, dtype_, shape_);
                    case 4:
-                        return new Tensor4(tensorName_, tensorDtype_, shape_);
+                        return new Tensor4(tensorName_, dtype_, shape_);
                    case 5:
-                        return new Tensor5(tensorName_, tensorDtype_, shape_);
+                        return new Tensor5(tensorName_, dtype_, shape_);
                    case 6:
-                        return new Tensor6(tensorName_, tensorDtype_, shape_);
+                        return new Tensor6(tensorName_, dtype_, shape_);
                }
                break;
-            default:
+            case TOSA_REF_TYPE_UNKNOWN:
+                assert(0);    // ConvertDType() did not map dtype_ to a valid TOSA_REF_TYPE
                break;
        }
        return nullptr;
    }
-
-    static Tensor* newTensor(DType type, const std::vector shape);
};
};    // namespace TosaReference
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
index dce6ceb..cfcb20d 160000
--- a/thirdparty/serialization_lib
+++ b/thirdparty/serialization_lib
@@ -1 +1 @@
-Subproject commit dce6cebbeb6c45625c4ef8fafb5a7775319101c5
+Subproject commit cfcb20d08c4c409bbcd2d2dde6ca5ecdac299454
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 0d98c17..28e4369 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -56,6 +56,13 @@ def parse_args():
         required=True,
         help="Reference model base directory",
     )
+    parser.add_argument(
+        "-p",
+        "--precise-mode",
+        dest="precise_mode",
+        action="store_true",
+        help="Run in precise mode (FP64)",
+    )
     parser.add_argument(
         "-v", "--verbose", dest="verbose", action="count", help="Verbose run"
     )
@@ -552,6 +559,9 @@ def run_test(args, test, framework):
     if args.debug_ref_model:
         ref_model_cmd.extend(["-D ALL", "-l high"])
 
+    if args.precise_mode:
+        ref_model_cmd.extend(["--precise_mode=1"])
+
     if args.valgrind:
         ref_model_cmd = [
             "valgrind",
@@ -594,7 +604,11 @@ def run_test(args, test, framework):
         )
         return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)
 
-    if tf_result.dtype == np.float16:
+    if args.precise_mode and (
+        tf_result.dtype == np.float16 or tf_result.dtype == np.float32
+    ):
+        tf_result = tf_result.astype(np.float64)
+    elif tf_result.dtype == np.float16:
         tf_result = tf_result.astype(np.float32)
     elif (
         tf_result.dtype == np.uint8
diff --git 
a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py index 95f6e7b..df5c0db 100644 --- a/verif/runner/tosa_refmodel_sut_run.py +++ b/verif/runner/tosa_refmodel_sut_run.py @@ -45,6 +45,9 @@ class TosaSUTRunner(TosaTestRunner): if args.ref_intermediates: cmd.extend(["--dump_intermediates", str(args.ref_intermediates)]) + if args.precise_mode: + cmd.extend(["--precise_mode=1"]) + # Run command and interpret tosa graph result via process return codes graphMessage = None try: diff --git a/verif/runner/tosa_verif_run_tests.py b/verif/runner/tosa_verif_run_tests.py index 6b5d77e..814c864 100644 --- a/verif/runner/tosa_verif_run_tests.py +++ b/verif/runner/tosa_verif_run_tests.py @@ -147,6 +147,13 @@ def parseArgs(argv): help="A TOSA level defines operator parameter ranges that an implementation shall support." "Config tosa_level for running the reference model only. Default is EIGHTK", ) + parser.add_argument( + "-p", + "--precise-mode", + dest="precise_mode", + action="store_true", + help="Run the reference model in precise mode (FP64)", + ) args = parser.parse_args(argv) -- cgit v1.2.1
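
Usage note (not part of the patch): a minimal sketch of how the new option reaches the
reference model, mirroring the command construction added to tosa_refmodel_sut_run.py
above. The binary location and the test descriptor name below are illustrative
assumptions, not taken from this patch.

    # Hypothetical driver showing the effect of -p / --precise-mode.
    import subprocess

    cmd = [
        "reference_model/build/reference_model/tosa_reference_model",  # hypothetical build path
        "--test_desc=desc.json",  # hypothetical test descriptor
    ]

    precise_mode = True  # what "-p" sets via argparse (action="store_true")
    if precise_mode:
        # Same flag the runners append; the reference model then evaluates all
        # FP16/FP32/BF16 tensors as FP64 (see ConvertDType in dtype.h).
        cmd.extend(["--precise_mode=1"])

    subprocess.run(cmd, check=True)

With -p set, the framework compiler runner above likewise widens FP16/FP32 framework
results to float64 (astype(np.float64)) before comparing them against the FP64 output
of the reference model.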