author    Tai Ly <tai.ly@arm.com>    2023-03-28 22:06:56 +0000
committer Tai Ly <tai.ly@arm.com>    2023-05-05 19:23:15 +0000
commit    a4d748b08accce06fab93e2d2b96e499b35ae89b
tree      20a3957e1f45f65f35d5d67ecce1618659e388f0
parent    0c71686875618b2e11290273b7a05b88ef8a8aae
[reference model] Add precise mode
This adds the --precise_mode=1 option to tosa_reference_model, which causes
the reference model to convert all floating-point tensors to FP64 tensors
and compute all operators accordingly. It also adds an optional -p argument
to the test runners tosa_verif_run_tests.py and
tosa_verif_framework_compiler_runner.py to run tests in precise mode.

Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
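A usage sketch for context (the angle-bracket placeholders are hypothetical;
only the --precise_mode/-p flags come from this change):

    tosa_reference_model <existing options> --precise_mode=1   # -p 1 is the cxxopts short form
    tosa_verif_run_tests.py -p <existing options>
    tosa_verif_framework_compiler_runner.py -p <existing options>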
-rw-r--r--  reference_model/CMakeLists.txt | 3
-rw-r--r--  reference_model/include/dtype.h | 132
-rw-r--r--  reference_model/include/func_config.h | 1
-rw-r--r--  reference_model/src/arith_util.h | 15
-rw-r--r--  reference_model/src/command_line_utils.h | 1
-rw-r--r--  reference_model/src/graph_node.h | 21
-rw-r--r--  reference_model/src/model_runner_impl.cc | 8
-rw-r--r--  reference_model/src/ops/activation_funcs.cc | 56
-rw-r--r--  reference_model/src/ops/activation_funcs.h | 8
-rw-r--r--  reference_model/src/ops/comparison.cc | 44
-rw-r--r--  reference_model/src/ops/comparison.h | 26
-rw-r--r--  reference_model/src/ops/control_flow.cc | 21
-rw-r--r--  reference_model/src/ops/data_layout.cc | 93
-rw-r--r--  reference_model/src/ops/data_layout.h | 22
-rw-r--r--  reference_model/src/ops/data_nodes.cc | 11
-rw-r--r--  reference_model/src/ops/data_nodes.h | 4
-rw-r--r--  reference_model/src/ops/ewise_binary.cc | 210
-rw-r--r--  reference_model/src/ops/ewise_binary.h | 36
-rw-r--r--  reference_model/src/ops/ewise_ternary.cc | 18
-rw-r--r--  reference_model/src/ops/ewise_ternary.h | 10
-rw-r--r--  reference_model/src/ops/ewise_unary.cc | 164
-rw-r--r--  reference_model/src/ops/ewise_unary.h | 10
-rw-r--r--  reference_model/src/ops/image.cc | 55
-rw-r--r--  reference_model/src/ops/image.h | 4
-rw-r--r--  reference_model/src/ops/op_factory.cc | 60
-rw-r--r--  reference_model/src/ops/op_factory.h | 82
-rw-r--r--  reference_model/src/ops/reduction.cc | 111
-rw-r--r--  reference_model/src/ops/reduction.h | 38
-rw-r--r--  reference_model/src/ops/scatter_gather.cc | 20
-rw-r--r--  reference_model/src/ops/scatter_gather.h | 6
-rw-r--r--  reference_model/src/ops/template_types.h | 96
-rw-r--r--  reference_model/src/ops/tensor_ops.cc | 203
-rw-r--r--  reference_model/src/ops/tensor_ops.h | 24
-rw-r--r--  reference_model/src/ops/type_conversion.cc | 116
-rw-r--r--  reference_model/src/ops/type_conversion.h | 99
-rw-r--r--  reference_model/src/subgraph_traverser.cc | 90
-rw-r--r--  reference_model/src/tensor.cc | 999
-rw-r--r--  reference_model/src/tensor.h | 209
m---------  thirdparty/serialization_lib | 0
-rwxr-xr-x  verif/frameworks/tosa_verif_framework_compiler_runner.py | 16
-rw-r--r--  verif/runner/tosa_refmodel_sut_run.py | 3
-rw-r--r--  verif/runner/tosa_verif_run_tests.py | 7
42 files changed, 2285 insertions(+), 867 deletions(-)
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 6494225..a086f4b 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.4)
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -109,6 +109,7 @@ set(PUBLIC_HEADERS)
list(APPEND PUBLIC_HEADERS
include/debug_modes.def
include/debug_types.h
+ include/dtype.h
include/func_config.h
include/func_debug.h
include/graph_status.h
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
new file mode 100644
index 0000000..4976b54
--- /dev/null
+++ b/reference_model/include/dtype.h
@@ -0,0 +1,132 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_DTYPE_H
+#define TOSA_REFERENCE_DTYPE_H
+
+#include "model_common.h"
+#include "tosa_generated.h"
+#include <cstdint>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// Reference Model version of tosa.fbs enum DType
+// Plus a FP64 data type for precise mode.
+enum TOSA_REF_TYPE : uint32_t
+{
+ TOSA_REF_TYPE_UNKNOWN = 0,
+ TOSA_REF_TYPE_BOOL = 1,
+ TOSA_REF_TYPE_UINT8 = 2,
+ TOSA_REF_TYPE_INT4 = 3,
+ TOSA_REF_TYPE_INT8 = 4,
+ TOSA_REF_TYPE_INT16 = 5,
+ TOSA_REF_TYPE_INT32 = 6,
+ TOSA_REF_TYPE_INT48 = 7,
+ TOSA_REF_TYPE_FP32 = 8,
+ TOSA_REF_TYPE_UINT16 = 9,
+ TOSA_REF_TYPE_FP16 = 10,
+ TOSA_REF_TYPE_BF16 = 11,
+ TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above
+};
+
+inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
+{
+ switch (e)
+ {
+ case TOSA_REF_TYPE_UNKNOWN:
+ return EnumNameDType(DType_UNKNOWN);
+ case TOSA_REF_TYPE_BOOL:
+ return EnumNameDType(DType_BOOL);
+ case TOSA_REF_TYPE_UINT8:
+ return EnumNameDType(DType_UINT8);
+ case TOSA_REF_TYPE_INT4:
+ return EnumNameDType(DType_INT4);
+ case TOSA_REF_TYPE_INT8:
+ return EnumNameDType(DType_INT8);
+ case TOSA_REF_TYPE_INT16:
+ return EnumNameDType(DType_INT16);
+ case TOSA_REF_TYPE_INT32:
+ return EnumNameDType(DType_INT32);
+ case TOSA_REF_TYPE_INT48:
+ return EnumNameDType(DType_INT48);
+ case TOSA_REF_TYPE_FP32:
+ return EnumNameDType(DType_FP32);
+ case TOSA_REF_TYPE_UINT16:
+ return EnumNameDType(DType_UINT16);
+ case TOSA_REF_TYPE_FP16:
+ return EnumNameDType(DType_FP16);
+ case TOSA_REF_TYPE_BF16:
+ return EnumNameDType(DType_BF16);
+ case TOSA_REF_TYPE_FP64:
+ return "FP64";
+ default:
+ assert(false);
+ }
+}
+
+// return corresponding TOSA_REF_TYPE for DType
+inline TOSA_REF_TYPE ConvertDType(const DType dtype)
+{
+ assert(DType_MAX == DType_BF16); // must update whenever DType_MAX changes
+
+ if (g_func_config.precise_mode)
+ {
+ // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64
+ switch (dtype)
+ {
+ case DType_FP16:
+ case DType_FP32:
+ case DType_BF16:
+ return TOSA_REF_TYPE_FP64;
+ default:
+ break;
+ }
+ }
+
+ switch (dtype)
+ {
+ case DType_BOOL:
+ return TOSA_REF_TYPE_BOOL;
+ case DType_UINT8:
+ return TOSA_REF_TYPE_UINT8;
+ case DType_INT4:
+ return TOSA_REF_TYPE_INT4;
+ case DType_INT8:
+ return TOSA_REF_TYPE_INT8;
+ case DType_INT16:
+ return TOSA_REF_TYPE_INT16;
+ case DType_INT32:
+ return TOSA_REF_TYPE_INT32;
+ case DType_INT48:
+ return TOSA_REF_TYPE_INT48;
+ case DType_FP32:
+ return TOSA_REF_TYPE_FP32;
+ case DType_UINT16:
+ return TOSA_REF_TYPE_UINT16;
+ case DType_FP16:
+ return TOSA_REF_TYPE_FP16;
+ case DType_BF16:
+ return TOSA_REF_TYPE_BF16;
+ default:
+ break;
+ }
+ return TOSA_REF_TYPE_UNKNOWN;
+}
+
+}; // namespace TosaReference
+
+#endif
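To make the mapping above concrete, a minimal C++ sketch (not part of the
commit; it assumes the reference model's global g_func_config, which dtype.h
itself reads via model_common.h):

#include "dtype.h"
#include <cassert>

void convert_dtype_sketch()
{
    using namespace TosaReference;

    // Precise mode: every floating-point DType resolves to FP64.
    g_func_config.precise_mode = 1;
    assert(ConvertDType(DType_FP16) == TOSA_REF_TYPE_FP64);
    assert(ConvertDType(DType_BF16) == TOSA_REF_TYPE_FP64);
    assert(ConvertDType(DType_FP32) == TOSA_REF_TYPE_FP64);
    // Integer and boolean types map one-to-one regardless of the mode.
    assert(ConvertDType(DType_INT32) == TOSA_REF_TYPE_INT32);

    // Normal mode: floats map one-to-one as well.
    g_func_config.precise_mode = 0;
    assert(ConvertDType(DType_FP32) == TOSA_REF_TYPE_FP32);
}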
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index c1f8ef6..b92845b 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -48,6 +48,7 @@ struct func_config_t
uint32_t tosa_profile = 1;
uint32_t dump_intermediates = 0;
std::string fp_format = "0.5";
+ uint32_t precise_mode = 0;
bool float_is_big_endian = false; // Set in arith_util.h by float_is_big_endian()
tosa_level_t tosa_level;
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 59bdf44..76f3bb8 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -30,11 +30,11 @@
#include <fenv.h>
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
+#include "dtype.h"
#include "func_config.h"
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
-#include "tosa_generated.h"
#include <Eigen/Core>
#include <bitset>
#include <cassert>
@@ -45,6 +45,7 @@
using namespace tosa;
using namespace std;
+using namespace TosaReference;
inline size_t _count_one(uint64_t val)
{
@@ -259,27 +260,27 @@ inline bool float_is_big_endian()
return f_as_bytes[0] != f_neg_as_bytes[0];
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
float fpTrunc(float f_in)
{
- /* Truncates a float value based on the DType it represents.*/
+ /* Truncates a float value based on the TOSA_REF_TYPE it represents.*/
switch (Dtype)
{
- case DType_BF16:
+ case TOSA_REF_TYPE_BF16:
truncateFloatToBFloat(&f_in, 1);
break;
- case DType_FP16:
+ case TOSA_REF_TYPE_FP16:
// Cast to temporary float16 value before casting back to float32
{
half_float::half h = half_float::half_cast<half_float::half, float>(f_in);
f_in = half_float::half_cast<float, half_float::half>(h);
break;
}
- case DType_FP32:
+ case TOSA_REF_TYPE_FP32:
// No-op for fp32
break;
default:
- ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));
+ ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype));
}
return f_in;
}
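A hedged illustration of the retyped fpTrunc (example values, not from the
commit):

float h = fpTrunc<TOSA_REF_TYPE_FP16>(0.1f); // rounded through half precision
float b = fpTrunc<TOSA_REF_TYPE_BF16>(0.1f); // mantissa truncated to bfloat16
float f = fpTrunc<TOSA_REF_TYPE_FP32>(0.1f); // no-op, returns 0.1f unchanged
// Any other type, including TOSA_REF_TYPE_FP64, hits the ASSERT_MSG above;
// the FP64 code paths added by this commit never call fpTrunc.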
diff --git a/reference_model/src/command_line_utils.h b/reference_model/src/command_line_utils.h
index 4e6e555..8fdff01 100644
--- a/reference_model/src/command_line_utils.h
+++ b/reference_model/src/command_line_utils.h
@@ -59,6 +59,7 @@ int func_model_parse_cmd_line(
("tosa_level", "Set TOSA level (NONE, EIGHTK)",
cxxopts::value<tosa_level_t>(func_config.tosa_level))
("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value<uint32_t>(func_config.dump_intermediates))
+ ("p,precise_mode", "Calculate floating point operations in FP64 (0/1)", cxxopts::value<uint32_t>(func_config.precise_mode))
("v,version", "print model version")
("i,input_tensor_file", "specify input tensor files", cxxopts::value<std::vector<std::string>>())
("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value<std::string>())
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index a9a336b..3433192 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -19,29 +19,30 @@
#include "attribute.h"
#include "subgraph_traverser.h"
#include "tensor.h"
-#include "tosa_generated.h"
#include <iostream>
-#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
+ template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE>;
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
- template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
+ template class TosaReference::OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
- template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>;
+ template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>;
#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
- template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
+ template class TosaReference::OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
-#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE>;
-#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
+#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>;
-#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
- template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>;
+#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
+ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;
#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \
- template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, OP_TYPE>;
+ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
index fa39c75..8089a1a 100644
--- a/reference_model/src/model_runner_impl.cc
+++ b/reference_model/src/model_runner_impl.cc
@@ -186,13 +186,13 @@ int ModelRunnerImpl::setInput(std::string input_name, uint8_t* raw_ptr, size_t s
int status = 0;
switch (tensor->getDtype())
{
- case DType::DType_FP16: {
+ case TOSA_REF_TYPE_FP16: {
auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
const int elements = size / sizeof(half_float::half);
status = setInput(input_name, ArrayProxy(elements, typed_ptr));
break;
}
- case DType::DType_FP32: {
+ case TOSA_REF_TYPE_FP32: {
auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
const int elements = size / sizeof(float);
status = setInput(input_name, ArrayProxy(elements, typed_ptr));
@@ -252,13 +252,13 @@ int ModelRunnerImpl::getOutput(std::string output_name, uint8_t* raw_ptr, size_t
int status = 0;
switch (tensor->getDtype())
{
- case DType::DType_FP16: {
+ case TOSA_REF_TYPE_FP16: {
auto typed_ptr = reinterpret_cast<half_float::half*>(raw_ptr);
const int elements = size / sizeof(half_float::half);
status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
break;
}
- case DType::DType_FP32: {
+ case TOSA_REF_TYPE_FP32: {
auto typed_ptr = reinterpret_cast<float*>(raw_ptr);
const int elements = size / sizeof(float);
status = tensor->writeToVector(ArrayProxy(elements, typed_ptr));
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 24bd077..6681d6d 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpClamp<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -32,9 +32,9 @@ int OpClamp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
{
InEigenType min = (InEigenType)attribute->min_fp();
InEigenType max = (InEigenType)attribute->max_fp();
@@ -43,8 +43,17 @@ int OpClamp<Rank, Dtype>::register_fcn()
this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); };
}
break;
- case DType_INT8:
- case DType_INT16:
+ case TOSA_REF_TYPE_FP64:
+ {
+ InEigenType min = (InEigenType)attribute->min_fp();
+ InEigenType max = (InEigenType)attribute->max_fp();
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
+ }
+ break;
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
{
InEigenType min = (InEigenType)attribute->min_int();
InEigenType max = (InEigenType)attribute->max_int();
@@ -53,19 +62,19 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpClamp<Rank, Dtype>::~OpClamp()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSigmoid<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -74,21 +83,24 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
return fpTrunc<Dtype>(1.f / (1.f + (expf(-1.f * a))));
};
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.L / (1.L + (exp(-1.L * a)))); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTanh<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -97,13 +109,16 @@ int OpTanh<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return tanh(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -115,11 +130,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
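This file shows the pattern the rest of the commit repeats: each existing
float case keeps its fpTrunc call, while the new FP64 case evaluates the
same expression in double precision with no truncation. Schematically
(hypothetical free functions, not code from the commit):

float  tanh_ref_fp16(float a)  { return fpTrunc<TOSA_REF_TYPE_FP16>(tanhf(a)); } // compute in fp32, truncate to fp16
double tanh_ref_fp64(double a) { return tanh(a); }                               // precise mode: full double precision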
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 9a697cd..2372fcb 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpClamp : public UnaryNode<Rank, Dtype>
{
public:
@@ -45,7 +45,7 @@ protected:
TosaClampAttribute* attribute;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSigmoid : public UnaryNode<Rank, Dtype>
{
public:
@@ -61,7 +61,7 @@ public:
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTanh : public UnaryNode<Rank, Dtype>
{
public:
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index a5711eb..8a084c7 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpEqual<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -31,20 +31,21 @@ int OpEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreater<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -53,20 +54,21 @@ int OpGreater<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreaterEqual<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -75,14 +77,15 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -93,13 +96,16 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
index 29e6b5a..263df6a 100644
--- a/reference_model/src/ops/comparison.h
+++ b/reference_model/src/ops/comparison.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,45 +24,45 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
-class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
-class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_GREATER, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
-class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index f573d5b..03ad6c6 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
- ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
- "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
+ "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
@@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_input_name = else_block->GetInputs()[i];
TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
- ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with then_block input type");
- ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with else_block input type");
ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
"OpCondIf: input tensor rank mismatch with then_block input rank");
@@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_output_name = else_block->GetOutputs()[i];
TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
- ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with then_block output type");
- ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with else_block output type");
ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
"OpCondIf: output tensor rank mismatch with then_block output rank");
@@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes()
TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
- ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with cond_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block output type");
ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
"OpWhileLoop: input tensor rank mismatch with cond_block input rank");
@@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
-
- TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
+ TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index a189466..442cef8 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::~OpConcat()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::eval()
{
@@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pad);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::~OpPad()
{
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -185,22 +185,23 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
switch (Dtype)
{
- case DType_BOOL:
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_BOOL:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
pad_value = (InEigenType)attribute->pad_const_int();
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
default:
@@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Reshape);
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
{
if (attribute)
delete attribute;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
@@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::~OpReverse()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().reverse(reverse_array);
@@ -384,7 +385,7 @@ int OpReverse<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Slice);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().slice(begin_array, size_array);
@@ -457,7 +458,7 @@ int OpSlice<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -469,14 +470,14 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Tile);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTile<Rank, Dtype>::eval()
{
// primary template shouldn't be called
- FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+ FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<1, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<2, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<3, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<4, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<5, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<6, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Transpose);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().shuffle(perm_array);
@@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -799,3 +807,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 3a6cb0d..94ce248 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpConcat : public GraphNode
{
public:
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpPad : public GraphNode
{
public:
@@ -66,7 +66,7 @@ protected:
TosaPadAttribute* attribute;
};
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
class OpReshape : public GraphNode
{
public:
@@ -90,7 +90,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReverse : public GraphNode
{
public:
@@ -112,7 +112,7 @@ protected:
Eigen::array<bool, Rank> reverse_array;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSlice : public GraphNode
{
public:
@@ -135,7 +135,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTileBase : public GraphNode
{
public:
@@ -156,7 +156,7 @@ protected:
};
// primary template for op tile
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTile : public OpTileBase<Rank, Dtype>
{
public:
@@ -170,12 +170,12 @@ protected:
// partial specialization for specific rank
#define DEF_OP_TILE_RANK(N) \
- template <DType Dtype> \
+ template <TOSA_REF_TYPE Dtype> \
class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
{ \
public: \
- OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
- : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
{} \
\
protected: \
@@ -191,7 +191,7 @@ DEF_OP_TILE_RANK(6)
#undef DEF_OP_TILE_RANK
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTranspose : public GraphNode
{
public:
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index f5304a5..b7f987a 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ int OpConst::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -52,11 +52,11 @@ OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpIdentity<Rank, Dtype>::~OpIdentity()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIdentity<Rank, Dtype>::checkTensorAttributes()
{
@@ -78,7 +78,7 @@ int OpIdentity<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIdentity<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor();
@@ -96,3 +96,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index 8761a08..395c667 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpIdentity : public GraphNode
{
public:
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 6aa0c0f..c5801e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
const Op& op_,
uint64_t id_)
@@ -37,11 +37,11 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
{}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -90,7 +90,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
{
const std::vector<int>& a_shape = a->getShape();
@@ -106,7 +106,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<Rank, InDtype, OutDtype>::eval()
{
this->broadcast();
@@ -124,7 +124,7 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval()
}
// still need to partial specialize this, or Eigen will throw static assertion
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<0, InDtype, OutDtype>::eval()
{
this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
@@ -132,12 +132,12 @@ int BinaryNode<0, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpAdd<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
int64_t res_in_64 = static_cast<int64_t>(a) + b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -146,36 +146,39 @@ int OpAdd<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
bool round = attribute->round();
int32_t num_bits = 0;
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
num_bits = 8;
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
num_bits = 16;
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
num_bits = 32;
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
@@ -195,69 +198,69 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseAnd<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseOr<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseXor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIntdiv<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
int64_t res_in_64 = static_cast<int64_t>(a) / b;
@@ -268,47 +271,47 @@ int OpIntdiv<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalAnd<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
};
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
};
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
@@ -316,32 +319,32 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalRightShift<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b);
};
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b);
};
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
@@ -349,91 +352,96 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalOr<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalXor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpMaximum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpMinimum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
int32_t shift = attribute->shift();
switch (InDtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
int64_t result;
if (shift > 0)
@@ -457,8 +465,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
return static_cast<OutEigenType>(result);
};
break;
- case DType_INT8:
- case DType_INT16:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
@@ -468,41 +476,44 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpMul<Rank, InDtype, OutDtype>::~OpMul()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPow<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSub<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
int64_t res_in_64 = static_cast<int64_t>(a) - b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -511,19 +522,22 @@ int OpSub<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -535,13 +549,13 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Table);
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::~OpTable()
{
if (attribute) delete attribute;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -573,12 +587,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::eval()
{
switch (InDtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
int32_t index = input_truncated - QInMin;
@@ -587,7 +601,7 @@ int OpTable<Rank, InDtype>::eval()
return value;
});
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
// 1. make sure input is int16 range
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
@@ -610,7 +624,7 @@ int OpTable<Rank, InDtype>::eval()
});
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return GraphNode::eval();
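
Note: the two TABLE paths above differ in shape. The INT8 case is a direct
256-entry lookup; the INT16 case (whose interpolation body is elided by the
hunk) indexes a 513-entry table and interpolates between adjacent entries,
matching the TableNumEntries constants in ewise_binary.h below. A standalone
sketch of the INT8 path, assuming QInMin/QInMax take their natural int8
values of -128/127:

    #include <algorithm>
    #include <array>
    #include <cstdint>

    int8_t table_int8(int8_t in, const std::array<int8_t, 256>& table)
    {
        const int32_t QInMin = -128, QInMax = 127;
        // clamp to the int8 range, then shift into [0, 255] for indexing
        const int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
        const int32_t index = input_truncated - QInMin;
        return table[index];
    }
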
@@ -630,11 +644,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
@@ -672,11 +688,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
@@ -684,15 +702,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
@@ -703,3 +724,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL);
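
Note: the pattern across the binary ops above is that the FP16/BF16/FP32
lambdas pass their result through fpTrunc<OutDtype>() so that math done in
float is rounded back to the output type's storage precision (a no-op for
FP32), while each new FP64 case returns the raw double. A minimal standalone
sketch of why that split matters, using float as the narrow type in place of
fpTrunc:

    #include <cstdio>

    // Narrow path: the sum is rounded to float storage, as fpTrunc rounds
    // FP16/BF16 results to their storage precision.
    float add_narrow(float a, float b) { return a + b; }

    // Precise-mode path: the sum stays in double.
    double add_precise(double a, double b) { return a + b; }

    int main()
    {
        const double tiny = 9.31322574615478515625e-10;    // 2^-30
        // 2^-30 is absorbed next to 64.0 in float's 24-bit mantissa,
        // but survives in the double path.
        printf("narrow:  %.3e\n", (double)add_narrow(64.0f, (float)tiny) - 64.0);    // 0.000e+00
        printf("precise: %.3e\n", add_precise(64.0, tiny) - 64.0);                   // 9.313e-10
        return 0;
    }
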
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 020ddb5..5f6e531 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ namespace TosaReference
// the way of registering lambda + .binaryExpr() might sacrifice performance here
// but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...}
// needs to revisit if performance becomes a bottleneck here
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
@@ -67,7 +67,7 @@ protected:
};
// primary class
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
@@ -86,7 +86,7 @@ public:
};
// partial specialization for rank 0
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
@@ -100,19 +100,19 @@ public:
};
#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
+ template <int Rank, TOSA_REF_TYPE Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
- : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
} \
- static constexpr DType InDtype = Dtype; \
- static constexpr DType OutDtype = Dtype; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ static constexpr TOSA_REF_TYPE InDtype = Dtype; \
+ static constexpr TOSA_REF_TYPE OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
virtual int register_fcn(); \
};
@@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
{
public:
@@ -154,7 +154,7 @@ protected:
TosaArithmeticRightShiftAttribute* attribute;
};
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
@@ -175,7 +175,7 @@ protected:
TosaMulAttribute* attribute;
};
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
class OpTable : public GraphNode
{
public:
@@ -185,9 +185,11 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
- static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
- static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
+ static constexpr TOSA_REF_TYPE TableDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16;
+ static constexpr TOSA_REF_TYPE OutDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32;
+ static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 256 : 513;
using InEigenType = typename GetEigenType<InDtype>::type;
using TableEigenType = typename GetEigenType<TableDtype>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
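
Note: for readers tracing the template changes, the reworked macro above
means DEF_TEMPLATE_BINARY_OP_DEFAULT(Add, ADD) now expands to approximately:

    template <int Rank, TOSA_REF_TYPE Dtype>
    class OpAdd : public BinaryNode<Rank, Dtype, Dtype>
    {
    public:
        OpAdd(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
            : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_ADD, id_)
        {
            register_fcn();
        }
        static constexpr TOSA_REF_TYPE InDtype  = Dtype;
        static constexpr TOSA_REF_TYPE OutDtype = Dtype;
        using InEigenType  = typename GetEigenType<InDtype>::type;
        using OutEigenType = typename GetEigenType<OutDtype>::type;
        virtual int register_fcn();
    };
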
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 4d53ae4..090ce29 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -29,11 +29,11 @@ OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSelectBase<Rank, Dtype>::~OpSelectBase()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -66,13 +66,13 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelectBase<Rank, Dtype>::eval()
{
FATAL_ERROR("shouldn't be called");
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::broadcast()
{
const std::vector<int>& cond_shape = this->cond->getShape();
@@ -90,7 +90,7 @@ int OpSelect<Rank, Dtype>::broadcast()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::eval()
{
this->broadcast();
@@ -102,7 +102,7 @@ int OpSelect<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpSelect<0, Dtype>::eval()
{
this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
@@ -118,6 +118,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
@@ -126,3 +127,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);
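
Note: OpSelect delegates the elementwise pick to Eigen's select() expression,
visible in the rank-0 eval() above. Its behaviour in isolation (standalone
sketch, using the FP64 element type added by this commit):

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main()
    {
        Eigen::Tensor<bool, 1> cond(3);
        Eigen::Tensor<double, 1> then_val(3), else_val(3), out(3);
        cond.setValues({ true, false, true });
        then_val.setValues({ 1.0, 2.0, 3.0 });
        else_val.setValues({ -1.0, -2.0, -3.0 });
        out = cond.select(then_val, else_val);    // per element: cond ? then : else
        std::cout << out << std::endl;            // 1 -2 3
        return 0;
    }
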
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index 75a2194..c6970cb 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ namespace TosaReference
// 3. Else_val: Rank N, type=<V>
// 4. Result: Rank N, type=<V>
// Cond, Then_val, Else_val need to be mutually-broadcastable
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSelectBase : public GraphNode
{
public:
@@ -39,7 +39,7 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- using CondEigenType = typename GetEigenType<DType_BOOL>::type;
+ using CondEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using InEigenType = typename GetEigenType<Dtype>::type;
using TCond = Eigen::Tensor<CondEigenType, Rank>;
using TIn = Eigen::Tensor<InEigenType, Rank>;
@@ -55,7 +55,7 @@ protected:
};
// primary class
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSelect : public OpSelectBase<Rank, Dtype>
{
public:
@@ -69,7 +69,7 @@ public:
};
// partial specialization for rank 0
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
{
public:
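
Note: the rank-0 specialization exists because the primary template's eval()
broadcasts cond/then_val/else_val to a common shape first, and there is no
broadcast step for scalars. The idiom in isolation, with stand-in names:

    // Primary template: broadcast inputs to a common shape, then select.
    template <int Rank, typename T>
    struct SelectKernel
    {
        void eval();    // broadcast(), then cond.select(then_val, else_val)
    };

    // Partial specialization for scalars: select directly, no broadcast.
    template <typename T>
    struct SelectKernel<0, T>
    {
        void eval();
    };
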
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 8dc37e2..514cb84 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -35,11 +35,11 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64
};
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::~UnaryNode()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -69,7 +69,7 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
@@ -77,71 +77,75 @@ int UnaryNode<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpAbs<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP32: // No fpTrunc for FP32 as it is a no-op
- case DType_INT32:
+ case TOSA_REF_TYPE_FP32: // No fpTrunc for FP32 as it is a no-op
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
break;
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? a : (-a)); };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseNot<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpCeil<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpClz<Rank, Dtype>::register_fcn()
{
int32_t num_bits;
switch (Dtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
num_bits = 32;
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
this->fcn = [num_bits](int32_t a) -> int32_t {
@@ -163,73 +167,82 @@ int OpClz<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpFloor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLog<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return log(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalNot<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -240,31 +253,37 @@ OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
register_fcn();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::~OpNegate()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
- ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
- ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
InEigenType result = -(a);
return fpTrunc<Dtype>(result);
};
break;
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ OutEigenType result = -(a);
+ return result;
+ };
+ break;
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a) -> OutEigenType {
int64_t res_in_64 = 0L - a;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
@@ -272,7 +291,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)");
int64_t max_clip_in_64, min_clip_in_64;
- if (Dtype == DType_INT16)
+ if (Dtype == TOSA_REF_TYPE_INT16)
{
max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max());
min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min());
@@ -285,7 +304,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64)));
};
break;
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a) -> OutEigenType {
int64_t res_in_64 = 0 - (a - attribute->input1_zp());
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
@@ -297,41 +316,47 @@ int OpNegate<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReciprocal<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpRsqrt<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / sqrt(a)); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -345,11 +370,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
@@ -358,20 +385,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
@@ -381,11 +412,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);
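
Note: the INT8 NEGATE case is the only unary path with zero-point handling;
the hunk above shows the wide subtraction but cuts off before the
output_zp/clamp tail of the lambda. A worked standalone sketch of the full
int8 arithmetic as the TOSA NEGATE definition specifies it (the exact tail
in the source may differ slightly):

    #include <algorithm>
    #include <cstdint>

    int8_t negate_int8(int8_t a, int32_t input1_zp, int32_t output_zp)
    {
        // negate the zero-point-adjusted value in a wide accumulator ...
        int64_t acc = 0 - (static_cast<int64_t>(a) - input1_zp);
        // ... then re-apply the output offset and clamp to int8
        acc += output_zp;
        acc = std::min<int64_t>(std::max<int64_t>(acc, INT8_MIN), INT8_MAX);
        return static_cast<int8_t>(acc);
    }

    // e.g. with input1_zp = output_zp = -128, negate_int8(-128, -128, -128)
    // returns -128: the encoding represents real 0, whose negation is 0.
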
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 16a4c88..21ee276 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class UnaryNode : public GraphNode
{
public:
@@ -45,11 +45,11 @@ protected:
};
#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
+ template <int Rank, TOSA_REF_TYPE Dtype> \
class Op##Opname : public UnaryNode<Rank, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
: UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
@@ -75,7 +75,7 @@ DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
#undef DEF_TEMPLATE_UNARY_OP
// Negate is the only unary op with attributes
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpNegate : public UnaryNode<Rank, Dtype>
{
public:
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 190b354..ca12cfe 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Resize);
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::~OpResize()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -64,7 +64,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -72,7 +73,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -87,7 +89,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::eval()
{
int in_batch = in->getShape()[0];
@@ -157,24 +159,38 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int32_t y = oy * scale_y_d + offset_y;
int32_t x = ox * scale_x_d + offset_x;
- float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
- float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
-
- int32_t iy = floor(fy);
- int32_t ix = floor(fx);
-
+ int32_t iy;
+ int32_t ix;
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
- (typeid(resize_t) == typeid(half_float::half)))
+ if (std::is_same<resize_t, double>::value)
{
- dy = (resize_t)(fy - iy);
- dx = (resize_t)(fx - ix);
+ const double fy_double = static_cast<double>(y) / static_cast<double>(scale_y_n);
+ const double fx_double = static_cast<double>(x) / static_cast<double>(scale_x_n);
+ iy = floor(fy_double);
+ ix = floor(fx_double);
+
+ dy = (resize_t)(fy_double - iy);
+ dx = (resize_t)(fx_double - ix);
}
else
{
- dy = (resize_t)(y - (iy * scale_y_n));
- dx = (resize_t)(x - (ix * scale_x_n));
+ const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
+ const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
+ iy = floor(fy);
+ ix = floor(fx);
+
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
+ {
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
+ }
+ else
+ {
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
+ }
}
int32_t iy0 = MAX(iy, 0);
@@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double);
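
Note: the restructured coordinate code above keeps three dy/dx conventions
in agreement: the new double path (for FP64 precise mode), the float path,
and the integer fixed-point path, where dy is the same fraction scaled by
scale_y_n. A small standalone check with scale_y_n = 4 and y = 7:

    #include <cassert>
    #include <cmath>
    #include <cstdint>

    int main()
    {
        const int32_t scale_y_n = 4;
        const int32_t y = 7;    // i.e. oy * scale_y_d + offset_y above

        const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);    // 1.75
        const int32_t iy = static_cast<int32_t>(std::floor(fy));                   // 1

        const float dy_float = fy - static_cast<float>(iy);    // 0.75
        const int32_t dy_int = y - (iy * scale_y_n);            // 3

        // the integer path stores the same fraction, scaled by scale_y_n
        assert(static_cast<float>(dy_int) == dy_float * static_cast<float>(scale_y_n));
        return 0;
    }
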
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 508d2c8..6d5a418 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
class OpResize : public GraphNode
{
public:
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1db3974..0a78884 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
Op opType,
TosaAttributeBase* attribute,
uint64_t id,
- DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank)
{
switch (opType)
@@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64);
break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
@@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
break;
case Op_MATMUL:
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16);
@@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
@@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
break;
// activation_funcs
@@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
break;
// ewise_binary
@@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
break;
case Op_ARITHMETIC_RIGHT_SHIFT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
@@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
@@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
break;
case Op_TABLE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
@@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
break;
case Op_BITWISE_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
@@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
@@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
@@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
break;
// ewise_ternary
@@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);
break;
// comparison
@@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
break;
// reduction
@@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
@@ -343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
@@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT16);
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ DEF_FACTORY_RESHAPE(OpReshape, FP64);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -392,6 +431,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
break;
// scatter_gather
@@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP64);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
break;
// image
@@ -448,6 +493,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64);
break;
// data_nodes
@@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
break;
// type_conversion
@@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
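
Note: the FP64 factory entries above are reachable because, with
--precise_mode=1, the model maps every serialized floating-point DType to
TOSA_REF_TYPE_FP64 before dispatching (via the ConvertDType helper that also
appears in the op_factory.h macros below). This is also why OpCast gains
FP64<->integer combinations plus FP64->FP64, but no FP64<->FP16/BF16/FP32
pairs: the narrower float types never reach the factory in precise mode. A
self-contained sketch of that promotion, with stand-in enums; the real
mapping lives in include/dtype.h:

    enum DType { DType_INT8, DType_FP16, DType_BF16, DType_FP32 };
    enum TOSA_REF_TYPE { TOSA_REF_TYPE_INT8, TOSA_REF_TYPE_FP16,
                         TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP64 };

    // Stand-in for include/dtype.h's ConvertDType (signature assumed).
    TOSA_REF_TYPE convert_dtype(DType dtype, bool precise_mode)
    {
        if (precise_mode)
        {
            switch (dtype)
            {
                case DType_FP16:
                case DType_BF16:
                case DType_FP32:
                    return TOSA_REF_TYPE_FP64;    // all float tensors become FP64
                default:
                    break;
            }
        }
        switch (dtype)    // otherwise the mapping is one-to-one
        {
            case DType_INT8: return TOSA_REF_TYPE_INT8;
            case DType_FP16: return TOSA_REF_TYPE_FP16;
            case DType_BF16: return TOSA_REF_TYPE_BF16;
            case DType_FP32: return TOSA_REF_TYPE_FP32;
        }
        return TOSA_REF_TYPE_FP32;    // unreachable for the enumerators above
    }
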
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 9117df4..f276e03 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,19 +23,19 @@
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
@@ -57,40 +57,42 @@
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
- && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
- } \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \
+ id); \
+ }
#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \
}
// Statement-expression to evaluate accumulate attribute in-place
@@ -108,35 +110,41 @@
FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
"of this attribute is required in order to determine the accumulate type."); \
} \
- accumDType; \
- }) \
+ ConvertDType(accumDType); \
+ })
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, int16_t>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, half_float::half>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
+ { \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, float>(sgt, attribute, id); \
}
-#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, double>(sgt, attribute, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -151,7 +159,7 @@
}
#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -165,7 +173,7 @@
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
switch (inputRank) \
{ \
@@ -180,7 +188,7 @@
}
#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -292,11 +300,11 @@ public:
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
uint64_t id,
- tosa::DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- tosa::DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- tosa::DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank);
};
}; // namespace TosaReference
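
Note on the macros above: dispatch now keys on TOSA_REF_TYPE rather than tosa::DType, and ACCUM_FROM_ATTRIBUTE runs its result through ConvertDType(). A minimal sketch of that conversion, assuming the dtype.h introduced by this patch behaves as the macros imply (floating-point types collapse to FP64 when --precise_mode=1; the exact enum spellings below are illustrative, not verbatim):

    // Sketch only -- not copied from the patch. Shows why the FP64 factory
    // branches become reachable: in precise mode every floating-point DType
    // is mapped onto the FP64 reference type before dispatch.
    TOSA_REF_TYPE ConvertDType(const tosa::DType dtype)
    {
        if (g_func_config.precise_mode)    // set by --precise_mode=1
        {
            switch (dtype)
            {
                case tosa::DType_FP16:
                case tosa::DType_BF16:
                case tosa::DType_FP32:
                    return TOSA_REF_TYPE_FP64;    // compute in double precision
                default:
                    break;                        // integer types keep their width
            }
        }
        switch (dtype)                            // otherwise a one-to-one mapping
        {
            case tosa::DType_FP32: return TOSA_REF_TYPE_FP32;
            case tosa::DType_FP16: return TOSA_REF_TYPE_FP16;
            case tosa::DType_INT8: return TOSA_REF_TYPE_INT8;
            // ... remaining enum values elided ...
            default:               return TOSA_REF_TYPE_UNKNOWN;
        }
    }
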
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index cd9d55f..bf8ba57 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -30,14 +30,14 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::~ReduceNode()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int ReduceNode<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -100,7 +100,7 @@ struct AnyReducer {
bool finalize(const bool accum) const { return accum; }
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
@@ -108,7 +108,7 @@ int OpReduceAll<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
@@ -116,7 +116,7 @@ int OpReduceAny<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -124,7 +124,7 @@ int OpReduceMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -132,35 +132,74 @@ int OpReduceMin<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
+struct ProductDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum *= val;
+ }
+ double initialize() const
+ {
+ return 1.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceProductDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, ProductDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
@@ -183,7 +222,7 @@ struct SumRequiresReducer {
SubgraphTraverser* parent_sgt;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
@@ -191,6 +230,40 @@ int OpReduceSumInt<Rank, Dtype>::eval()
return GraphNode::eval();
}
+struct SumDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum += val;
+ }
+ double initialize() const
+ {
+ return 0.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceSumDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, SumDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
@@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
@@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
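
The ProductDoubleReducer and SumDoubleReducer structs added above follow Eigen's custom-reducer protocol: initialize() supplies the identity element, reduce() folds one value into the accumulator, finalize() post-processes, and PacketAccess = false keeps Eigen on the scalar path so no packet overloads are needed. A self-contained example of that protocol, independent of the reference model:

    // Standalone illustration (my own example, not part of the patch).
    #include <unsupported/Eigen/CXX11/Tensor>
    #include <cstdio>

    struct SumDoubleReducer
    {
        static const bool PacketAccess = false;                 // scalar-only reduction
        void reduce(const double val, double* accum) { *accum += val; }
        double initialize() const { return 0.0; }               // identity element
        double finalize(const double accum) const { return accum; }
    };

    int main()
    {
        Eigen::Tensor<double, 2> t(2, 3);
        t.setValues({ { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 } });

        Eigen::array<Eigen::Index, 1> dims = { 1 };             // reduce along columns
        Eigen::Tensor<double, 1> sums = t.reduce(dims, SumDoubleReducer());

        std::printf("%g %g\n", sums(0), sums(1));               // prints: 6 15
        return 0;
    }
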
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
index 6e98a76..aeb9f1d 100644
--- a/reference_model/src/ops/reduction.h
+++ b/reference_model/src/ops/reduction.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class ReduceNode : public GraphNode
{
public:
@@ -44,7 +44,7 @@ protected:
TosaAxisAttribute* attribute;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceAll : public ReduceNode<Rank, Dtype>
{
public:
@@ -54,7 +54,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceAny : public ReduceNode<Rank, Dtype>
{
public:
@@ -64,7 +64,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceMax : public ReduceNode<Rank, Dtype>
{
public:
@@ -74,7 +74,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceMin : public ReduceNode<Rank, Dtype>
{
public:
@@ -84,7 +84,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceProduct : public ReduceNode<Rank, Dtype>
{
public:
@@ -94,7 +94,17 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpReduceProductDouble : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceProductDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceSum : public ReduceNode<Rank, Dtype>
{
public:
@@ -104,7 +114,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceSumInt : public ReduceNode<Rank, Dtype>
{
public:
@@ -114,6 +124,16 @@ public:
virtual int eval();
};
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpReduceSumDouble : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceSumDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
}; // namespace TosaReference
#endif
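
Note that the double variants are separate node classes that reuse ReduceNode and bind the same Op codes (Op_REDUCE_PRODUCT, Op_REDUCE_SUM), so choosing between, say, OpReduceSum and OpReduceSumDouble has to happen in the factory. A hypothetical sketch of that wiring (the actual op_factory.cc hunks are not shown in this excerpt):

    // Hypothetical factory registration -- illustrative only.
    case Op_REDUCE_SUM:
        DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
        DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
        DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
        DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
        DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);   // precise mode
        break;
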
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index bcd8ce5..80b6c58 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -29,11 +29,11 @@ OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
setRequiredOperands(2, 1);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpGather<Dtype>::~OpGather()
{}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpGather<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -96,7 +96,7 @@ int OpGather<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpGather<Dtype>::eval()
{
for (int32_t n = 0; n < N; n++)
@@ -116,7 +116,7 @@ int OpGather<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -125,11 +125,11 @@ OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
setRequiredOperands(3, 1);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpScatter<Dtype>::~OpScatter()
{}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpScatter<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -199,7 +199,7 @@ int OpScatter<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpScatter<Dtype>::eval()
{
// Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
@@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
@@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64);
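
GATHER and SCATTER need no FP64-specific kernel: the element loops are already type-generic, so support comes entirely from the two extra explicit instantiations above. Assuming the DEF_INSTANTIATE_* helpers keep their usual shape (their definitions sit outside this excerpt), each line expands to roughly:

    // Assumed expansion -- the real macro lives alongside the other
    // DEF_INSTANTIATE_* helpers in this patch.
    #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) \
        template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE>;

    // so DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64) becomes:
    template class TosaReference::OpGather<TOSA_REF_TYPE_FP64>;
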
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index af09153..fb675a9 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpGather : public GraphNode
{
public:
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOutput>* output;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpScatter : public GraphNode
{
public:
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index ece14b1..6dd6e76 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -16,11 +16,10 @@
#ifndef OP_TEMPLATE_TYPES_H
#define OP_TEMPLATE_TYPES_H
-#include "tosa_generated.h"
-#include <Eigen/CXX11/Tensor>
+#include "dtype.h"
#include "half.hpp"
+#include <Eigen/CXX11/Tensor>
#include <Eigen/Core>
-#include "arith_util.h"
using namespace tosa;
@@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>;
template <typename T>
using Tensor6 = TensorTemplate<ETensor6<T>>;
-template <DType type>
+template <TOSA_REF_TYPE type>
struct GetEigenType;
template <>
-struct GetEigenType<DType_FP32>
+struct GetEigenType<TOSA_REF_TYPE_FP64>
+{
+ using type = double;
+};
+template <>
+struct GetEigenType<TOSA_REF_TYPE_FP32>
{
using type = float;
};
template <>
-struct GetEigenType<DType_FP16>
+struct GetEigenType<TOSA_REF_TYPE_FP16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_BF16>
+struct GetEigenType<TOSA_REF_TYPE_BF16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_INT32>
+struct GetEigenType<TOSA_REF_TYPE_INT32>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT48>
+struct GetEigenType<TOSA_REF_TYPE_INT48>
{
using type = int64_t;
};
template <>
-struct GetEigenType<DType_BOOL>
+struct GetEigenType<TOSA_REF_TYPE_BOOL>
{
using type = bool;
};
template <>
-struct GetEigenType<DType_UINT8>
+struct GetEigenType<TOSA_REF_TYPE_UINT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_UINT16>
+struct GetEigenType<TOSA_REF_TYPE_UINT16>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT4>
+struct GetEigenType<TOSA_REF_TYPE_INT4>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT8>
+struct GetEigenType<TOSA_REF_TYPE_INT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT16>
+struct GetEigenType<TOSA_REF_TYPE_INT16>
{
using type = int32_t;
};
/* Get Accumulate Eigen Type:
-Same behaviour as GetEigenType for all DTypes except the
-single specialised case of DType_FP16. */
-template <DType Dtype>
+Same behaviour as GetEigenType for all DTYPEs except the
+single specialised case of TOSA_REF_TYPE_FP16. */
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType;
template <>
-struct GetAccEigenType<DType_FP16>
+struct GetAccEigenType<TOSA_REF_TYPE_FP16>
{
using type = half_float::half;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType
{
using type = typename GetEigenType<Dtype>::type;
};
// Meta function to get number of bits
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetNumBits
{
static constexpr int32_t value = 0;
};
template <>
-struct GetNumBits<DType_BOOL>
+struct GetNumBits<TOSA_REF_TYPE_BOOL>
{
static constexpr int32_t value = 1;
};
template <>
-struct GetNumBits<DType_UINT8>
+struct GetNumBits<TOSA_REF_TYPE_UINT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_UINT16>
+struct GetNumBits<TOSA_REF_TYPE_UINT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT4>
+struct GetNumBits<TOSA_REF_TYPE_INT4>
{
static constexpr int32_t value = 4;
};
template <>
-struct GetNumBits<DType_INT8>
+struct GetNumBits<TOSA_REF_TYPE_INT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_INT16>
+struct GetNumBits<TOSA_REF_TYPE_INT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT32>
+struct GetNumBits<TOSA_REF_TYPE_INT32>
{
static constexpr int32_t value = 32;
};
template <>
-struct GetNumBits<DType_INT48>
+struct GetNumBits<TOSA_REF_TYPE_INT48>
{
static constexpr int32_t value = 48;
};
template <>
-struct GetNumBits<DType_FP16>
+struct GetNumBits<TOSA_REF_TYPE_FP16>
{
static constexpr int32_t value = 16;
};
// Meta function to get quantized min/max in compile time
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMin
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT8>
+struct GetQMin<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT16>
+struct GetQMin<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_INT4>
+struct GetQMin<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(-8);
};
template <>
-struct GetQMin<DType_INT8>
+struct GetQMin<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(-128);
};
template <>
-struct GetQMin<DType_INT16>
+struct GetQMin<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(-32768);
};
template <>
-struct GetQMin<DType_INT32>
+struct GetQMin<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = -(INT64_C(1) << 31);
};
template <>
-struct GetQMin<DType_INT48>
+struct GetQMin<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = -(INT64_C(1) << 47);
};
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMax
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMax<DType_UINT8>
+struct GetQMax<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(255);
};
template <>
-struct GetQMax<DType_UINT16>
+struct GetQMax<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(65535);
};
template <>
-struct GetQMax<DType_INT4>
+struct GetQMax<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(7);
};
template <>
-struct GetQMax<DType_INT8>
+struct GetQMax<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(127);
};
template <>
-struct GetQMax<DType_INT16>
+struct GetQMax<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(32767);
};
template <>
-struct GetQMax<DType_INT32>
+struct GetQMax<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = (INT64_C(1) << 31) - 1;
};
template <>
-struct GetQMax<DType_INT48>
+struct GetQMax<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = (INT64_C(1) << 47) - 1;
};
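
The pattern throughout template_types.h: GetEigenType declares only a primary template and defines one full specialization per enumerator, so an unsupported TOSA_REF_TYPE fails at compile time, while GetNumBits/GetQMin/GetQMax instead fall back to a defined default of 0. A standalone, compilable demo of the first idiom (the enum and names are stand-ins, not the real types):

    // Self-contained demo of the specialization idiom above (C++17).
    #include <cstdint>
    #include <type_traits>

    enum class RefType { FP32, FP64, INT8 };          // stand-in for TOSA_REF_TYPE

    template <RefType T> struct EigenTypeOf;          // primary left undefined on purpose

    template <> struct EigenTypeOf<RefType::FP32> { using type = float; };
    template <> struct EigenTypeOf<RefType::FP64> { using type = double; };
    template <> struct EigenTypeOf<RefType::INT8> { using type = int32_t; };  // widened storage

    static_assert(std::is_same_v<EigenTypeOf<RefType::FP64>::type, double>,
                  "FP64 tensors are stored as native doubles");

    int main() { return 0; }
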
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b3845df..f8fd323 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
}
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
- uint32_t conv_dimension,
- std::vector<int32_t> input_shape,
- std::vector<int32_t> output_shape,
- std::vector<int32_t> weights,
- uint32_t offset_kernel,
- DType InDtype,
- DType WeightDtype,
- std::string& msg)
+ uint32_t conv_dimension,
+ std::vector<int32_t> input_shape,
+ std::vector<int32_t> output_shape,
+ std::vector<int32_t> weights,
+ uint32_t offset_kernel,
+ TOSA_REF_TYPE InDtype,
+ TOSA_REF_TYPE WeightDtype,
+ std::string& msg)
{
if (attribute->pad().size() != (2 * conv_dimension))
{
@@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute,
return 1;
}
- if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+ if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
+ {
msg = "Input zero point must be zero for non-int8 data";
return 1;
}
- if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+ if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
+ {
msg = "Weight zero point must be zero for non-int8 data";
return 1;
}
@@ -318,7 +320,7 @@ int check_fft_shape(const std::vector<int32_t>& in_real,
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -330,14 +332,14 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -355,7 +357,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- if (outputs[0]->getDtype() != DType_INT32)
+ if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
{
printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
return 1;
@@ -400,7 +402,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
@@ -410,7 +412,7 @@ int OpArgMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -422,14 +424,14 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -449,8 +451,10 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
+ "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
std::string msg;
if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -466,8 +470,9 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype, DType AccDtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
+ int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
@@ -495,7 +500,7 @@ ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -531,7 +536,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+ TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
DEBUG_INFO(OP,
"perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
@@ -556,7 +561,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
pad[3] = std::make_pair(0, 0);
ETensor4<InEigenType> input_val = this->in->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
}
@@ -604,7 +609,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
dm2_h.contract(dm2_w, contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
- if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
+ if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
+ Dtype != TOSA_REF_TYPE_FP64)
{
try
{
@@ -632,7 +638,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -644,14 +650,14 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -688,7 +694,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -781,7 +787,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -817,7 +823,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -826,7 +832,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -838,14 +844,14 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -882,7 +888,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -959,7 +965,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1020,7 +1026,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1029,10 +1035,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1041,14 +1047,14 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
@@ -1226,14 +1232,14 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1265,13 +1271,15 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1289,7 +1297,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1299,7 +1307,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1307,7 +1315,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1319,14 +1327,14 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1382,13 +1390,15 @@ int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
+ "OpMatMul: A zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
+ "OpMatMul: B zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1396,7 +1406,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
TIn a_val = this->a->getTensor();
TIn b_val = this->b->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
a_val = a_val - (InEigenType)attribute->a_zp();
b_val = b_val - (InEigenType)attribute->b_zp();
@@ -1434,7 +1444,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1443,7 +1453,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1455,14 +1465,14 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1493,7 +1503,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -1586,10 +1596,8 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
-OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
@@ -1598,14 +1606,14 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(FFT);
}
-template <DType Dtype>
-OpFFT2d<Dtype>::~OpFFT2d() {
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::~OpFFT2d()
+{
if (attribute)
delete attribute;
}
-
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1643,7 +1651,7 @@ int OpFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
int in_real_batch = this->in_real->getShape()[0];
@@ -1709,7 +1717,7 @@ int OpFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1719,11 +1727,11 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
setRequiredRank(3);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d() {}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1759,7 +1767,7 @@ int OpRFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
int32_t in_batch = in->getShape()[0];
@@ -1815,10 +1823,10 @@ int OpRFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1827,14 +1835,14 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1923,13 +1931,15 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1985,7 +1995,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -2040,7 +2050,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2062,6 +2073,7 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
@@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
@@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
@@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
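
One behavioural detail in the OpAvgPool2d hunks above: FP64 now joins FP16/BF16/FP32 on the plain floating-point divide path, while integer types keep the scaled/rounded path, and the divisor comes from calculate_div_map_1d, which counts only the non-padded taps under each kernel position. A from-scratch sketch of that counting idea (simplified; the real version works on Eigen tensors and takes separate left/right pads):

    // Illustrative reimplementation of the divisor-map idea -- not the patch's code.
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    std::vector<int32_t> div_map_1d(int in_size, int out_size, int kernel, int stride, int pad_left)
    {
        std::vector<int32_t> result(out_size);
        for (int o = 0; o < out_size; o++)
        {
            int start = o * stride - pad_left;                        // first tap, input coords
            int end   = start + kernel;                               // one past the last tap
            int valid = std::min(end, in_size) - std::max(start, 0);  // taps on real data
            result[o] = std::max(valid, 1);                           // avoid divide-by-zero
        }
        return result;
    }

    int main()
    {
        // in=4, kernel=3, stride=1, pad_left=1 (pad_right implied 1) -> out=4
        for (int32_t v : div_map_1d(4, 4, 3, 1, 1))
            std::printf("%d ", v);                                    // prints: 2 3 3 2
        std::printf("\n");
        return 0;
    }
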
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 9ef4a58..df53f2b 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -24,7 +24,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArgMax : public GraphNode
{
public:
@@ -35,7 +35,7 @@ public:
virtual int eval();
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_INT32>::type;
using TIn = Eigen::Tensor<InEigenType, Rank>;
using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* output;
};
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
class OpAvgPool2d : public GraphNode
{
public:
@@ -74,7 +74,7 @@ protected:
ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -104,7 +104,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -134,7 +134,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -164,7 +164,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpFullyConnected : public GraphNode
{
public:
@@ -195,7 +195,7 @@ protected:
tosa::TosaFullyConnectedAttribute* attribute;
};
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
class OpMatMul : public GraphNode
{
public:
@@ -227,7 +227,7 @@ protected:
tosa::TosaMatMulAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpMaxPool2d : public GraphNode
{
public:
@@ -248,7 +248,7 @@ protected:
tosa::TosaPoolAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpFFT2d : public GraphNode
{
public:
@@ -271,7 +271,7 @@ protected:
tosa::TosaFFTAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpRFFT2d : public GraphNode
{
public:
@@ -292,7 +292,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out_imag;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpTransposeConv2d : public GraphNode
{
public:
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 9034add..68ffb1f 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Rescale);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -69,31 +69,33 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
- if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
+ if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->input_zp() != 0))
{
- printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
+ if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->output_zp() != 0))
{
- printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if (attribute->scale32() && (InDtype == DType_INT48))
+ if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
{
printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
return 1;
@@ -108,7 +110,7 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::eval()
{
int32_t input_zp = attribute->input_zp();
@@ -237,7 +239,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -247,11 +249,11 @@ OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::~OpCast()
{}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -281,7 +283,7 @@ int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::eval()
{
this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
@@ -289,7 +291,7 @@ int OpCast<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
CastHelper<InDtype, OutDtype>::CastHelper()
{
fcn = [](InEigenType in) -> OutEigenType {
@@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BOOL>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
{
fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
}
-template <DType OutDtype>
-CastHelper<DType_BOOL, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
{
fcn = [](bool in) -> OutEigenType {
OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
@@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
{
// Integer data converted to fp16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_FP16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
{
// fp32 data converted to fp16 (stored as fp32)
fcn = [](float in) -> float {
- float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
+ float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
return out;
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BF16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
{
// Integer data converted to bf16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_BF16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
{
// fp32 data converted to bf16 (stored as fp32)
fcn = [](float in) -> float {
- return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision
+ return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
};
}
-template <DType OutDtype>
-CastHelper<DType_FP16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
{
// fp16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_FP16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since fp16 values treated internally as their fp32 representation
fcn = [](float in) -> OutEigenType {
@@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_BF16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
{
// bf16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_BF16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since bf16 values treated as truncated fp32 internally
fcn = [](InEigenType in) -> OutEigenType {
@@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP32>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
{
// Integer data converted to fp32
fcn = [](InEigenType in) -> float {
@@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_FP32, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
{
// fp32 data converted to integer
fcn = [](float in) -> OutEigenType {
@@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper()
};
}
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
+{
+ switch (OutDtype)
+ {
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
+ // fp64 data converted to integer
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = std::rint(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ // no op
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+ break;
+ default:
+ ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
+ }
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
@@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
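[Editor's note, not part of the patch: the new CastHelper<TOSA_REF_TYPE_FP64, OutDtype> rounds with std::rint, which is round-to-nearest-even under the default floating-point environment, and then saturates to the output range. A minimal sketch of that behaviour, assuming INT8 bounds of -128/127 (supplied by GetQMin/GetQMax in the real code):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Stand-in for the FP64 -> INT8 cast path: round half to even, then clamp.
    static int8_t cast_fp64_to_int8(double in)
    {
        double out = std::rint(in);      // 2.5 -> 2.0, 3.5 -> 4.0 under FE_TONEAREST
        out = std::max(out, -128.0);     // OutMin for INT8
        out = std::min(out, 127.0);      // OutMax for INT8
        return static_cast<int8_t>(out);
    }

    int main()
    {
        std::printf("%d %d %d\n", cast_fp64_to_int8(2.5), cast_fp64_to_int8(-3.7), cast_fp64_to_int8(300.0));
        // prints: 2 -4 127
        return 0;
    }
]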
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index e2fc6e2..98799a0 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpRescale : public GraphNode
{
public:
@@ -46,7 +46,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class CastHelper
{
public:
@@ -64,12 +64,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_BOOL>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_BOOL>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -81,11 +81,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_BOOL, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_BOOL>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -100,12 +100,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_FP16>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP16>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -117,11 +117,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_FP16, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP16, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_FP16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -137,11 +137,11 @@ private:
};
template <>
-class CastHelper<DType_FP32, DType_FP16>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
- using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -153,12 +153,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_BF16>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_BF16>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -170,11 +170,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_BF16, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_BF16, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_BF16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -190,11 +190,11 @@ private:
};
template <>
-class CastHelper<DType_FP32, DType_BF16>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
- using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -206,12 +206,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_FP32>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP32>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -224,11 +224,11 @@ private:
};
template <>
-class CastHelper<DType_FP16, DType_FP32>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>
{
public:
- using InEigenType = typename GetEigenType<DType_FP16>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -241,11 +241,11 @@ private:
};
template <>
-class CastHelper<DType_BF16, DType_FP32>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>
{
public:
- using InEigenType = typename GetEigenType<DType_BF16>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -257,11 +257,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_FP32, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP32, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -276,7 +276,26 @@ private:
FcnType fcn;
};
-template <int Rank, DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP64, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpCast : public GraphNode
{
public:
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index e7641ba..4508291 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -138,9 +138,9 @@ int SubgraphTraverser::initializeGraph()
for (auto op : block->GetOperators())
{
// translated TosaSerializationOperator to GraphNode
- DType input_dtype = DType_UNKNOWN;
- DType output_dtype = DType_UNKNOWN;
- DType weight_dtype = DType_UNKNOWN;
+ TOSA_REF_TYPE input_dtype = TOSA_REF_TYPE_UNKNOWN;
+ TOSA_REF_TYPE output_dtype = TOSA_REF_TYPE_UNKNOWN;
+ TOSA_REF_TYPE weight_dtype = TOSA_REF_TYPE_UNKNOWN;
uint32_t input_rank = 0;
uint32_t output_rank = 0;
uint32_t weight_rank = 0;
@@ -185,7 +185,7 @@ int SubgraphTraverser::initializeGraph()
!input_tensor,
"SubgraphTraverser::initializeGraph(): fail to get input tensor %s from TosaSerializationHandler",
input_name.c_str());
- input_dtype = input_tensor->GetDtype();
+ input_dtype = ConvertDType(input_tensor->GetDtype());
input_rank = input_tensor->GetShape().size();
}
@@ -207,7 +207,7 @@ int SubgraphTraverser::initializeGraph()
!weight_tensor,
"SubgraphTraverser::initializeGraph(): fail to get weight tensor %s from TosaSerializationHandler",
weight_name.c_str());
- weight_dtype = weight_tensor->GetDtype();
+ weight_dtype = ConvertDType(weight_tensor->GetDtype());
weight_rank = weight_tensor->GetShape().size();
}
@@ -220,7 +220,7 @@ int SubgraphTraverser::initializeGraph()
!output_tensor,
"SubgraphTraverser::initializeGraph(): fail to get output tensor %s from TosaSerializationHandler",
output_name.c_str());
- output_dtype = output_tensor->GetDtype();
+ output_dtype = ConvertDType(output_tensor->GetDtype());
output_rank = output_tensor->GetShape().size();
DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
@@ -246,16 +246,16 @@ int SubgraphTraverser::initializeGraph()
fprintf(g_func_debug.func_debug_file,
"SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d) "
"-> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
- EnumNamesDType()[output_dtype], output_rank);
+ EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
+ EnumNameTOSAREFTYPE(output_dtype), output_rank);
}
else
{
fprintf(g_func_debug.func_debug_file,
"SubgraphTraverser::initializeGraph(): OpFactory could not allocate op %8s input=(%s rank %d), "
"weight=(%s rank %d) -> (%s rank %d)",
- EnumNamesOp()[op->GetOp()], EnumNamesDType()[input_dtype], input_rank,
- EnumNamesDType()[weight_dtype], weight_rank, EnumNamesDType()[output_dtype], output_rank);
+ EnumNamesOp()[op->GetOp()], EnumNameTOSAREFTYPE(input_dtype), input_rank,
+ EnumNameTOSAREFTYPE(weight_dtype), weight_rank, EnumNameTOSAREFTYPE(output_dtype), output_rank);
}
for (auto& ts : op->GetInputTensorNames())
@@ -309,7 +309,7 @@ int SubgraphTraverser::initializeGraph()
TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetShape(), ts->GetShape().size());
SUBGRAPH_ERROR_IF(!tensor, "SubgraphTraverser::initializeGraph(): Unsupported tensor name=%s, type=%s, rank=%d",
- ts->GetName().c_str(), EnumNamesDType()[ts->GetDtype()], (int)ts->GetShape().size());
+ ts->GetName().c_str(), EnumNameDType(ts->GetDtype()), (int)ts->GetShape().size());
addTensor(tensor);
}
@@ -411,73 +411,89 @@ int SubgraphTraverser::allocateTensor()
if (!ts->GetData().empty())
{
DEBUG_INFO(GT, "Allocating tensor %s", tensor->getName().c_str());
- switch (ts->GetDtype())
+ auto serialization_dtype = ts->GetDtype();
+ switch (serialization_dtype)
{
- case DType_INT4:
- {
+ case DType_INT4: {
std::vector<int8_t> i4_data;
TosaSerializationHandler::ConvertU8toI4(ts->GetData(), tensor->getElementCount(), i4_data);
std::vector<int32_t> i32_data(i4_data.begin(), i4_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT8:
- {
+ case DType_INT8: {
std::vector<int8_t> i8_data;
TosaSerializationHandler::ConvertU8toI8(ts->GetData(), tensor->getElementCount(), i8_data);
std::vector<int32_t> i32_data(i8_data.begin(), i8_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT16:
- {
+ case DType_INT16: {
std::vector<int16_t> i16_data;
TosaSerializationHandler::ConvertU8toI16(ts->GetData(), tensor->getElementCount(), i16_data);
std::vector<int32_t> i32_data(i16_data.begin(), i16_data.end());
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT32:
- {
+ case DType_INT32: {
std::vector<int32_t> i32_data;
TosaSerializationHandler::ConvertU8toI32(ts->GetData(), tensor->getElementCount(), i32_data);
tensor->setTensorValueInt32(i32_data.size(), i32_data.data());
}
break;
- case DType_INT48:
- {
+ case DType_INT48: {
std::vector<int64_t> i64_data;
TosaSerializationHandler::ConvertU8toI48(ts->GetData(), tensor->getElementCount(), i64_data);
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
- case DType_FP16:
- {
+ case DType_FP16: {
// Interpret f16 data as float
std::vector<float> f16_data;
TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
- tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(f16_data.begin(), f16_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+ }
}
break;
- case DType_BF16:
- {
+ case DType_BF16: {
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
// Ensure valid bfloat16 stored in each float
for (auto f : fp32_data)
ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
}
break;
- case DType_FP32:
- {
+ case DType_FP32: {
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
}
break;
- case DType_BOOL:
- {
+ case DType_BOOL: {
std::vector<bool> bool_data;
TosaSerializationHandler::ConvertU8toBool(ts->GetData(), tensor->getElementCount(), bool_data);
@@ -493,7 +509,7 @@ int SubgraphTraverser::allocateTensor()
break;
default:
SUBGRAPH_ERROR_IF(true, "SubgraphTraverser::initializeGraph(): Unsupported tensor type %s.",
- EnumNamesDType()[ts->GetDtype()]);
+ EnumNameDType(ts->GetDtype()));
}
}
}
@@ -802,14 +818,14 @@ int SubgraphTraverser::validateGraph()
if (g_func_config.tosa_profile == 0)
{
- DType dtype = currTensor->getDtype();
+ TOSA_REF_TYPE dtype = currTensor->getDtype();
        // Floating-point disallowed
- if (dtype == DType_FP32 || dtype == DType_FP16)
+ if (dtype == TOSA_REF_TYPE_FP32 || dtype == TOSA_REF_TYPE_FP16)
{
WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
"disabled, but %s tensor %s found\n",
- EnumNamesDType()[dtype], currTensor->getName().c_str());
+ EnumNameTOSAREFTYPE(dtype), currTensor->getName().c_str());
return 1;
}
}
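[Editor's note, not part of the patch: in precise mode every floating-point tensor is declared TOSA_REF_TYPE_FP64, so the allocateTensor() branches above widen the decoded FP16/BF16/FP32 payloads to double before storing them. The widening is exact, since every float value is representable as a double; a minimal sketch:

    #include <vector>

    // Sketch of the FP64 branches in allocateTensor(): the serialized data has
    // already been decoded to float and is widened element-wise to double.
    std::vector<double> widen_to_fp64(const std::vector<float>& f32_data)
    {
        return std::vector<double>(f32_data.begin(), f32_data.end()); // lossless
    }
]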
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 7af2069..08d5b2a 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -22,11 +22,14 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-TosaReference::Tensor::Tensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
+TosaReference::Tensor::Tensor(const std::string tensorName_,
+ const DType serializationDtype_,
+ const std::vector<int> shape_)
+ : tensorName(tensorName_)
+ , serializationDtype(serializationDtype_)
+ , shape(shape_)
+ , tensorDtype(ConvertDType(serializationDtype_))
{
- tensorName = std::string(tensorName_);
- tensorDtype = tensorDtype_;
- shape = std::vector<int>(shape_);
producer = nullptr;
isValid = false;
consumers.clear();
@@ -75,7 +78,7 @@ int TosaReference::Tensor::addConsumer(GraphNode* node)
int TosaReference::Tensor::dumpTensorParams(FILE* out) const
{
- fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNamesDType()[getDtype()],
+ fprintf(out, "Name: %s DType=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), EnumNameTOSAREFTYPE(getDtype()),
getIsValid(), getRank(), getShapeAsString().c_str());
return 0;
@@ -83,7 +86,7 @@ int TosaReference::Tensor::dumpTensorParams(FILE* out) const
int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
{
- out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " isValid=" << getIsValid()
+ out << "Name: " << getName() << " DType=" << EnumNameTOSAREFTYPE(getDtype()) << " isValid=" << getIsValid()
<< " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n";
return 0;
@@ -92,28 +95,33 @@ int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
int TosaReference::Tensor::readFromNpyFile(const char* filename)
{
uint32_t elements = getElementCount();
- float* fdatabuf = nullptr;
+ double* f64databuf = nullptr;
+ float* f32databuf = nullptr;
half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
- DType dtype = getDtype();
+ TOSA_REF_TYPE dtype = getDtype();
+ DType serialization_dtype = getSerializationDtype();
- switch (dtype)
+ assert(dtype == ConvertDType(serialization_dtype));
+ // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16
+ assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 ||
+ serialization_dtype == DType_BF16);
+
+ switch (serialization_dtype)
{
case DType_FP32:
case DType_BF16:
- fdatabuf = (float*)calloc(sizeof(float), elements);
- ASSERT_MEM(fdatabuf);
+ f32databuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(f32databuf);
- nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, f32databuf);
break;
case DType_FP16:
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
ASSERT_MEM(f16databuf);
- fdatabuf = (float*)calloc(sizeof(float), elements);
- ASSERT_MEM(fdatabuf);
nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf);
break;
@@ -141,7 +149,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
nperror = NumpyUtilities::readFromNpyFile(filename, elements, bdatabuf);
break;
default:
- FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ FATAL_ERROR("unknown tensor type=%s", EnumNameDType(serialization_dtype));
}
switch (nperror)
@@ -154,7 +162,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename);
case NumpyUtilities::FILE_TYPE_MISMATCH:
FATAL_ERROR("readFromNpyFile: Tensor type %s and Numpy file type mismatch for tensor %s filename %s",
- EnumNamesDType()[getDtype()], getName().c_str(), filename);
+ EnumNameTOSAREFTYPE(getDtype()), getName().c_str(), filename);
case NumpyUtilities::HEADER_PARSE_ERROR:
FATAL_ERROR("Numpy header parsing error for file: %s", filename);
case NumpyUtilities::BUFFER_SIZE_MISMATCH:
@@ -166,75 +174,133 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
switch (dtype)
{
- case DType_FP16:
+ case TOSA_REF_TYPE_FP16:
// Convert from fp16 to fp32 so that fp16 values can be manipulated as float
+ f32databuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(f32databuf);
for (uint32_t i=0; i < elements; i++) {
- fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
+ f32databuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
- if (setTensorValueFloat(elements, fdatabuf))
+ if (setTensorValueFloat(elements, f32databuf))
{
free(f16databuf);
- free(fdatabuf);
+ free(f32databuf);
return 1;
}
break;
- case DType_BF16:
+ case TOSA_REF_TYPE_BF16:
for (uint32_t i=0; i < elements; i++)
{
ASSERT_MSG(
- checkValidBFloat(fdatabuf[i]),
+ checkValidBFloat(f32databuf[i]),
"Input float value not a valid bfloat16 value."
);
}
- if (setTensorValueFloat(elements, fdatabuf))
+ if (setTensorValueFloat(elements, f32databuf))
{
- free(fdatabuf);
+ free(f32databuf);
return 1;
}
break;
- case DType_FP32:
- if (setTensorValueFloat(elements, fdatabuf))
+ case TOSA_REF_TYPE_FP32:
+ if (setTensorValueFloat(elements, f32databuf))
{
- free(fdatabuf);
+ free(f32databuf);
return 1;
}
break;
- case DType_INT32:
- case DType_UINT8:
- case DType_INT4:
- case DType_INT8:
- case DType_INT16:
- case DType_UINT16:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_UINT16:
if (setTensorValueInt32(elements, i32databuf))
{
free(i32databuf);
return 1;
}
break;
- case DType_INT48:
+ case TOSA_REF_TYPE_INT48:
if (setTensorValueInt64(elements, i64databuf))
{
free(i64databuf);
return 1;
}
break;
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
if (setTensorValueBool(elements, bdatabuf))
{
free(i32databuf);
return 1;
}
break;
+ case TOSA_REF_TYPE_FP64:
+ switch (serialization_dtype)
+ {
+ case DType_FP16:
+ // FP16 -> FP64
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+ for (uint32_t i = 0; i < elements; i++)
+ {
+ f64databuf[i] = half_float::half_cast<double, half_float::half>(f16databuf[i]);
+ }
+ if (setTensorValueDouble(elements, f64databuf))
+ {
+ free(f16databuf);
+ free(f64databuf);
+ return 1;
+ }
+ break;
+ case DType_BF16:
+ // BF16 -> FP64
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+ for (uint32_t i = 0; i < elements; i++)
+ {
+ ASSERT_MSG(checkValidBFloat(f32databuf[i]), "Input float value not a valid bfloat16 value.");
+ f64databuf[i] = static_cast<double>(f32databuf[i]);
+ }
+ if (setTensorValueDouble(elements, f64databuf))
+ {
+ free(f32databuf);
+ free(f64databuf);
+ return 1;
+ }
+ break;
+ case DType_FP32:
+ // FP32 -> FP64
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+ for (uint32_t i = 0; i < elements; i++)
+ {
+ f64databuf[i] = static_cast<double>(f32databuf[i]);
+ }
+ if (setTensorValueDouble(elements, f64databuf))
+ {
+ free(f32databuf);
+ free(f64databuf);
+ return 1;
+ }
+ break;
+ default:
+ FATAL_ERROR("unexpected tensor type=%s and original tensor type=%s", EnumNameTOSAREFTYPE(dtype),
+ EnumNameDType(serialization_dtype));
+ }
+ break;
default:
- FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ FATAL_ERROR("unsupported tensor type=%s", EnumNameTOSAREFTYPE(dtype));
}
setIsValid();
- if (fdatabuf)
- free(fdatabuf);
+ if (f32databuf)
+ free(f32databuf);
if (f16databuf)
free(f16databuf);
+ if (f64databuf)
+ free(f64databuf);
if (i32databuf)
free(i32databuf);
if (i64databuf)
@@ -247,58 +313,59 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
- float* fdatabuf = nullptr;
+ float* f32databuf = nullptr;
+ double* f64databuf = nullptr;
half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
uint32_t elements = getElementCount();
- DType dtype = getDtype();
+ const TOSA_REF_TYPE dtype = getDtype();
switch (dtype)
{
- case DType_FP32:
- case DType_BF16:
- fdatabuf = (float*)calloc(sizeof(float), elements);
- ASSERT_MEM(fdatabuf);
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_BF16:
+ f32databuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(f32databuf);
- if (getTensorValueFloat(elements, fdatabuf))
+ if (getTensorValueFloat(elements, f32databuf))
{
- free(fdatabuf);
+ free(f32databuf);
return 1;
}
- nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, f32databuf);
- free(fdatabuf);
+ free(f32databuf);
break;
- case DType_FP16:
- fdatabuf = (float*)calloc(sizeof(float), elements);
- ASSERT_MEM(fdatabuf);
+ case TOSA_REF_TYPE_FP16:
+ f32databuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(f32databuf);
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
ASSERT_MEM(f16databuf);
- if (getTensorValueFloat(elements, fdatabuf))
+ if (getTensorValueFloat(elements, f32databuf))
{
- free(fdatabuf);
+ free(f32databuf);
free(f16databuf);
return 1;
}
// Convert fp32 to fp16 so that output file contains valid fp16 data
for (uint32_t i=0; i < elements; i++) {
- f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]);
+ f16databuf[i] = half_float::half_cast<half_float::half, float>(f32databuf[i]);
}
nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf);
- free(fdatabuf);
+ free(f32databuf);
free(f16databuf);
break;
- case DType_INT32:
- case DType_UINT8:
- case DType_INT4:
- case DType_INT8:
- case DType_INT16:
- case DType_UINT16:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_UINT16:
i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
ASSERT_MEM(i32databuf);
@@ -312,7 +379,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(i32databuf);
break;
- case DType_INT48:
+ case TOSA_REF_TYPE_INT48:
i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
ASSERT_MEM(i64databuf);
@@ -326,7 +393,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(i64databuf);
break;
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
bdatabuf = (bool*)calloc(sizeof(bool), elements);
ASSERT_MEM(bdatabuf);
@@ -340,8 +407,22 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(bdatabuf);
break;
- default:
- FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ case TOSA_REF_TYPE_FP64:
+ // @todo : support FP64 dtype
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+
+ if (getTensorValueDouble(elements, f64databuf))
+ {
+ free(f64databuf);
+ return 1;
+ }
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, f64databuf);
+
+ free(f64databuf);
+ break;
+ case TOSA_REF_TYPE_UNKNOWN:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNameTOSAREFTYPE(getDtype()));
}
switch (nperror)
@@ -386,11 +467,11 @@ int TosaReference::TensorTemplate<T>::copyValueFrom(TosaReference::Tensor* src)
return 1; \
} \
\
- uint32_t src_rank = src->getRank(); \
- uint32_t dst_rank = this->getRank(); \
- DType src_dtype = src->getDtype(); \
- DType dst_dtype = this->getDtype(); \
- bool tensor_match = true; \
+ const uint32_t src_rank = src->getRank(); \
+ const uint32_t dst_rank = this->getRank(); \
+ const TOSA_REF_TYPE src_dtype = src->getDtype(); \
+ const TOSA_REF_TYPE dst_dtype = this->getDtype(); \
+ bool tensor_match = true; \
\
if ((src_rank != dst_rank) || (src_dtype != dst_dtype)) \
{ \
@@ -413,8 +494,9 @@ int TosaReference::TensorTemplate<T>::copyValueFrom(TosaReference::Tensor* src)
{ \
WARNING("source tensor %s (rank=%u, dtype=%s, shape=%s) doesn't match destination tensor %s (rank=%u, " \
"dtype=%s, shape=%s)", \
- src->getName().c_str(), src_rank, EnumNamesDType()[src_dtype], src->getShapeAsString().c_str(), \
- this->getName().c_str(), dst_rank, EnumNamesDType()[dst_dtype], this->getShapeAsString().c_str()); \
+ src->getName().c_str(), src_rank, EnumNameTOSAREFTYPE(src_dtype), src->getShapeAsString().c_str(), \
+ this->getName().c_str(), dst_rank, EnumNameTOSAREFTYPE(dst_dtype), \
+ this->getShapeAsString().c_str()); \
return 1; \
} \
\
@@ -429,6 +511,13 @@ DEF_CTENSOR_COPY_VALUE_FROM(3, float)
DEF_CTENSOR_COPY_VALUE_FROM(4, float)
DEF_CTENSOR_COPY_VALUE_FROM(5, float)
DEF_CTENSOR_COPY_VALUE_FROM(6, float)
+DEF_CTENSOR_COPY_VALUE_FROM(0, double)
+DEF_CTENSOR_COPY_VALUE_FROM(1, double)
+DEF_CTENSOR_COPY_VALUE_FROM(2, double)
+DEF_CTENSOR_COPY_VALUE_FROM(3, double)
+DEF_CTENSOR_COPY_VALUE_FROM(4, double)
+DEF_CTENSOR_COPY_VALUE_FROM(5, double)
+DEF_CTENSOR_COPY_VALUE_FROM(6, double)
DEF_CTENSOR_COPY_VALUE_FROM(0, int32_t)
DEF_CTENSOR_COPY_VALUE_FROM(1, int32_t)
DEF_CTENSOR_COPY_VALUE_FROM(2, int32_t)
@@ -453,13 +542,37 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
#undef DEF_CTENSOR_COPY_VALUE_FROM
+int TosaReference::Tensor::readfromVector(const ArrayProxy<double> vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case TOSA_REF_TYPE_FP64:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueDouble(elements, vals.data());
+ break;
+ default:
+            WARNING("The input type (double) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameTOSAREFTYPE(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
{
uint32_t elements = getElementCount();
switch (getDtype())
{
- case DType_FP16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_FP32:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -469,7 +582,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
setTensorValueFloat(elements, vals.data());
break;
- case DType_BF16:
+ case TOSA_REF_TYPE_BF16:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -489,7 +602,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
break;
default:
WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
setIsValid();
@@ -503,7 +616,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val
switch (getDtype())
{
- case DType_FP16:
+ case TOSA_REF_TYPE_FP16:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -521,7 +634,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<half_float::half> val
break;
default:
WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
setIsValid();
@@ -533,12 +646,12 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<int32_t> vals)
uint32_t elements = getElementCount();
switch (getDtype())
{
- case DType_INT32:
- case DType_UINT8:
- case DType_INT4:
- case DType_INT8:
- case DType_INT16:
- case DType_UINT16:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_UINT16:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -550,7 +663,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<int32_t> vals)
break;
default:
WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
setIsValid();
@@ -562,7 +675,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<int64_t> vals)
uint32_t elements = getElementCount();
switch (getDtype())
{
- case DType_INT48:
+ case TOSA_REF_TYPE_INT48:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -574,7 +687,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<int64_t> vals)
break;
default:
WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
setIsValid();
@@ -587,7 +700,7 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<unsigned char> vals)
switch (getDtype())
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -599,21 +712,45 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<unsigned char> vals)
break;
default:
WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
setIsValid();
return 0;
}
+int TosaReference::Tensor::writeToVector(ArrayProxy<double> vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case TOSA_REF_TYPE_FP64:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueDouble(elements, vals.data());
+ break;
+ default:
+            WARNING("The output type (double) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameTOSAREFTYPE(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
{
uint32_t elements = getElementCount();
switch (getDtype())
{
- case DType_FP16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_FP32:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -623,7 +760,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
getTensorValueFloat(elements, vals.data());
break;
- case DType_BF16:
+ case TOSA_REF_TYPE_BF16:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -644,7 +781,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<float> vals)
break;
default:
WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
return 0;
@@ -657,7 +794,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals)
switch (getDtype())
{
- case DType_FP16:
+ case TOSA_REF_TYPE_FP16:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -675,7 +812,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<half_float::half> vals)
break;
default:
WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
return 0;
@@ -687,12 +824,12 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<int32_t> vals)
switch (getDtype())
{
- case DType_INT32:
- case DType_UINT8:
- case DType_INT4:
- case DType_INT8:
- case DType_INT16:
- case DType_UINT16:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_UINT16:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -704,7 +841,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<int32_t> vals)
break;
default:
WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
return 0;
@@ -716,7 +853,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<int64_t> vals)
switch (getDtype())
{
- case tosa::DType_INT48:
+ case TOSA_REF_TYPE_INT48:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -728,7 +865,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<int64_t> vals)
break;
default:
WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
return 0;
@@ -740,7 +877,7 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals)
switch (getDtype())
{
- case tosa::DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -752,13 +889,165 @@ int TosaReference::Tensor::writeToVector(ArrayProxy<unsigned char> vals)
break;
default:
WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).",
- EnumNameDType(getDtype()));
+ EnumNameTOSAREFTYPE(getDtype()));
return -2;
}
return 0;
}
template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueDouble(const size_t buflen, const double* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals)
{
FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. "
@@ -1367,6 +1656,196 @@ int TosaReference::Tensor6<bool>::setTensorValueBool(const size_t bufLen, const
}
template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueDouble should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
int TosaReference::TensorTemplate<T>::getTensorValueFloat(const size_t bufLen, float* vals) const
{
FATAL_ERROR("TensorTemplate<T>::getTensorValueFloat should not be called. "
@@ -2127,6 +2606,82 @@ int TosaReference::Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool*
}
template <>
+int TosaReference::Tensor0<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<double>();
+
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<double>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<double>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<double>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<double>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<double>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<double>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<double>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
int TosaReference::Tensor0<float>::allocate()
{
ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
@@ -2428,6 +2983,230 @@ int TosaReference::Tensor6<bool>::allocate()
}
template <>
+int TosaReference::Tensor0<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), "[ %%%sf ]\n", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, fp_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<double>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[32];
+ snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", g_func_config.fp_format.c_str());
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
int TosaReference::Tensor0<float>::dumpTensor(FILE* out) const
{
char fp_fmt[32];
@@ -3342,6 +4121,14 @@ int TosaReference::TensorTemplate<T>::dumpTensor(FILE* out) const
}
// template explicit specialization
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<double, 6>>;
+
template class TosaReference::TensorTemplate<Eigen::Tensor<float, 0>>;
template class TosaReference::TensorTemplate<Eigen::Tensor<float, 1>>;
template class TosaReference::TensorTemplate<Eigen::Tensor<float, 2>>;
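Note on the dumpTensor specializations above: they all follow one pattern. The element format is assembled at runtime from g_func_config.fp_format, which supplies only the width/precision portion; snprintf wraps it into " %<fmt>f ". Each rank then adds one level of bracket nesting around the loops. A minimal standalone sketch of the same pattern, assuming a hypothetical fp_format value of "0.5":

    #include <cstdio>
    #include <string>

    int main()
    {
        // Stand-in for g_func_config.fp_format; "0.5" is a hypothetical value.
        const std::string fp_format = "0.5";

        // " %%%sf " with "0.5" substituted yields the element format " %0.5f ".
        char fp_fmt[32];
        std::snprintf(fp_fmt, sizeof(fp_fmt), " %%%sf ", fp_format.c_str());

        // Rank-2 dump: one bracket pair per dimension, values printed in the
        // innermost loop, mirroring Tensor2<double>::dumpTensor above.
        const double data[2][3] = { { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 } };
        std::printf("[");
        for (int i0 = 0; i0 < 2; i0++)
        {
            std::printf("[");
            for (int i1 = 0; i1 < 3; i1++)
            {
                std::printf(fp_fmt, data[i0][i1]);
            }
            std::printf("]\n");
        }
        std::printf("]\n");
        return 0;
    }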
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index d5f1de8..08ee8bf 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -17,9 +17,9 @@
#define TOSA_REFERENCE_TENSOR_H
#include "array_proxy.h"
+#include "dtype.h"
#include "model_common.h"
#include "ops/template_types.h"
-#include "tosa_generated.h"
#include "tosa_serialization_handler.h"
#include <Eigen/CXX11/Tensor>
#include <list>
@@ -34,7 +34,7 @@ class GraphNode;
class Tensor
{
public:
- Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_);
+ Tensor(const std::string tensorName_, const DType serializationDtype_, const std::vector<int> shape_);
virtual ~Tensor();
@@ -212,19 +212,26 @@ public:
return shape.size();
}
- const DType getDtype() const
+ const TOSA_REF_TYPE getDtype() const
{
return tensorDtype;
}
+ const DType getSerializationDtype() const
+ {
+ return serializationDtype;
+ }
+
virtual int dumpTensor(FILE* out) const = 0;
virtual int dumpTensorParams(FILE* out) const;
virtual int dumpTensorParams(std::ostream& out) const;
+ virtual int setTensorValueDouble(const size_t bufLen, const double* vals) = 0;
virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
+ virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const = 0;
virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
@@ -234,12 +241,14 @@ public:
virtual int writeToNpyFile(const char* filename) const;
virtual int copyValueFrom(Tensor* tensor) = 0;
+ virtual int readfromVector(const ArrayProxy<double> vals);
virtual int readfromVector(const ArrayProxy<float> vals);
virtual int readfromVector(const ArrayProxy<half_float::half> vals);
virtual int readfromVector(const ArrayProxy<int32_t> vals);
virtual int readfromVector(const ArrayProxy<int64_t> vals);
virtual int readfromVector(const ArrayProxy<unsigned char> vals);
+ virtual int writeToVector(ArrayProxy<double> vals);
virtual int writeToVector(ArrayProxy<float> vals);
virtual int writeToVector(ArrayProxy<half_float::half> vals);
virtual int writeToVector(ArrayProxy<int32_t> vals);
@@ -258,10 +267,11 @@ public:
virtual bool is_allocated() = 0;
protected:
- std::string tensorName;
- DType tensorDtype;
+ const std::string tensorName;
+ const DType serializationDtype;
+ const std::vector<int> shape;
+ const TOSA_REF_TYPE tensorDtype;
int isValid;
- std::vector<int> shape;
int isSubgraphInput;
int isSubgraphOutput;
bool isAllocated;
@@ -284,8 +294,8 @@ template <class T>
class TensorTemplate : public Tensor
{
public:
- TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
- : Tensor(tensorName_, tensorDtype_, shape_)
+ TensorTemplate(const std::string tensorName_, const DType dtype_, const std::vector<int> shape_)
+ : Tensor(tensorName_, dtype_, shape_)
{
tensor = nullptr;
}
@@ -330,10 +340,13 @@ public:
virtual int dumpTensor(FILE* out) const;
+ virtual int setTensorValueDouble(const size_t bufLen, const double* vals);
virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
+
+ virtual int getTensorValueDouble(const size_t bufLen, double* fbuf) const;
virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
@@ -363,6 +376,21 @@ template <>
int Tensor6<float>::allocate();
template <>
+int Tensor0<double>::allocate();
+template <>
+int Tensor1<double>::allocate();
+template <>
+int Tensor2<double>::allocate();
+template <>
+int Tensor3<double>::allocate();
+template <>
+int Tensor4<double>::allocate();
+template <>
+int Tensor5<double>::allocate();
+template <>
+int Tensor6<double>::allocate();
+
+template <>
int Tensor0<int32_t>::allocate();
template <>
int Tensor1<int32_t>::allocate();
@@ -423,6 +451,21 @@ template <>
int Tensor6<float>::copyValueFrom(Tensor* src);
template <>
+int Tensor0<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<double>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<double>::copyValueFrom(Tensor* src);
+
+template <>
int Tensor0<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int32_t>::copyValueFrom(Tensor* src);
@@ -558,6 +601,36 @@ template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
+int Tensor0<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor1<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor2<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor3<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor4<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor5<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+template <>
+int Tensor6<double>::setTensorValueDouble(const size_t bufLen, const double* vals);
+
+template <>
+int Tensor0<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor1<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor2<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor3<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor4<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor5<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+template <>
+int Tensor6<double>::getTensorValueDouble(const size_t bufLen, double* vals) const;
+
+template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
@@ -587,7 +660,6 @@ int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
-// assume we only dump float type tensor now
template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
@@ -603,6 +675,20 @@ int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
+int Tensor0<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<double>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<double>::dumpTensor(FILE* out) const;
+template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
@@ -648,100 +734,119 @@ int Tensor6<bool>::dumpTensor(FILE* out) const;
class TensorFactory
{
public:
- static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
+ static Tensor* newTensor(std::string tensorName_, DType dtype_, std::vector<int> shape_, const uint32_t rank)
{
+ TOSA_REF_TYPE tensorDtype_ = ConvertDType(dtype_);
switch (tensorDtype_)
{
- case DType_FP32:
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<float>(tensorName_, dtype_, shape_);
+ case 1:
+ return new Tensor1<float>(tensorName_, dtype_, shape_);
+ case 2:
+ return new Tensor2<float>(tensorName_, dtype_, shape_);
+ case 3:
+ return new Tensor3<float>(tensorName_, dtype_, shape_);
+ case 4:
+ return new Tensor4<float>(tensorName_, dtype_, shape_);
+ case 5:
+ return new Tensor5<float>(tensorName_, dtype_, shape_);
+ case 6:
+ return new Tensor6<float>(tensorName_, dtype_, shape_);
+ }
+ break;
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_UINT8:
+ case TOSA_REF_TYPE_INT4:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_UINT16:
switch (rank)
{
case 0:
- return new Tensor0<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor0<int32_t>(tensorName_, dtype_, shape_);
case 1:
- return new Tensor1<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor1<int32_t>(tensorName_, dtype_, shape_);
case 2:
- return new Tensor2<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor2<int32_t>(tensorName_, dtype_, shape_);
case 3:
- return new Tensor3<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor3<int32_t>(tensorName_, dtype_, shape_);
case 4:
- return new Tensor4<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor4<int32_t>(tensorName_, dtype_, shape_);
case 5:
- return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor5<int32_t>(tensorName_, dtype_, shape_);
case 6:
- return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
+ return new Tensor6<int32_t>(tensorName_, dtype_, shape_);
}
break;
- case DType_INT32:
- case DType_UINT8:
- case DType_INT4:
- case DType_INT8:
- case DType_INT16:
- case DType_UINT16:
+ case TOSA_REF_TYPE_INT48:
switch (rank)
{
case 0:
- return new Tensor0<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor0<int64_t>(tensorName_, dtype_, shape_);
case 1:
- return new Tensor1<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor1<int64_t>(tensorName_, dtype_, shape_);
case 2:
- return new Tensor2<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor2<int64_t>(tensorName_, dtype_, shape_);
case 3:
- return new Tensor3<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor3<int64_t>(tensorName_, dtype_, shape_);
case 4:
- return new Tensor4<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor4<int64_t>(tensorName_, dtype_, shape_);
case 5:
- return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor5<int64_t>(tensorName_, dtype_, shape_);
case 6:
- return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor6<int64_t>(tensorName_, dtype_, shape_);
}
break;
- case DType_INT48:
+ case TOSA_REF_TYPE_BOOL:
switch (rank)
{
case 0:
- return new Tensor0<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor0<bool>(tensorName_, dtype_, shape_);
case 1:
- return new Tensor1<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor1<bool>(tensorName_, dtype_, shape_);
case 2:
- return new Tensor2<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor2<bool>(tensorName_, dtype_, shape_);
case 3:
- return new Tensor3<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor3<bool>(tensorName_, dtype_, shape_);
case 4:
- return new Tensor4<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor4<bool>(tensorName_, dtype_, shape_);
case 5:
- return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor5<bool>(tensorName_, dtype_, shape_);
case 6:
- return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
+ return new Tensor6<bool>(tensorName_, dtype_, shape_);
}
break;
- case DType_BOOL:
+ case TOSA_REF_TYPE_FP64:
switch (rank)
{
case 0:
- return new Tensor0<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor0<double>(tensorName_, dtype_, shape_);
case 1:
- return new Tensor1<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor1<double>(tensorName_, dtype_, shape_);
case 2:
- return new Tensor2<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor2<double>(tensorName_, dtype_, shape_);
case 3:
- return new Tensor3<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor3<double>(tensorName_, dtype_, shape_);
case 4:
- return new Tensor4<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor4<double>(tensorName_, dtype_, shape_);
case 5:
- return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor5<double>(tensorName_, dtype_, shape_);
case 6:
- return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
+ return new Tensor6<double>(tensorName_, dtype_, shape_);
}
break;
- default:
+ case TOSA_REF_TYPE_UNKNOWN:
+ assert(0); // tensorDtype_ is uninitialized
break;
}
return nullptr;
}
-
- static Tensor* newTensor(DType type, const std::vector<int> shape);
};
}; // namespace TosaReference
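In the rewritten TensorFactory above, the serialization DType is first mapped to a TOSA_REF_TYPE by ConvertDType (declared in the new include/dtype.h), and dispatch then happens on the (reference type, rank) pair; the TOSA_REF_TYPE_FP64 arm is what routes precise-mode graphs onto the new Tensor*<double> classes. The sketch below illustrates the widening this implies; it is not the actual ConvertDType, and the explicit precise_mode parameter is an assumption standing in for however the real function consults the run configuration:

    // Illustrative only: real mapping lives in include/dtype.h.
    enum DType
    {
        DType_FP16,
        DType_BF16,
        DType_FP32,
        DType_INT32
    };

    enum TOSA_REF_TYPE
    {
        TOSA_REF_TYPE_FP16,
        TOSA_REF_TYPE_BF16,
        TOSA_REF_TYPE_FP32,
        TOSA_REF_TYPE_FP64,
        TOSA_REF_TYPE_INT32,
        TOSA_REF_TYPE_UNKNOWN
    };

    TOSA_REF_TYPE ConvertDType(DType dtype, bool precise_mode)
    {
        if (precise_mode)
        {
            // Precise mode: every floating-point tensor is widened to FP64
            // so operators are evaluated in double precision.
            switch (dtype)
            {
                case DType_FP16:
                case DType_BF16:
                case DType_FP32:
                    return TOSA_REF_TYPE_FP64;
                default:
                    break;    // non-float types pass through unchanged
            }
        }
        switch (dtype)
        {
            case DType_FP16:
                return TOSA_REF_TYPE_FP16;
            case DType_BF16:
                return TOSA_REF_TYPE_BF16;
            case DType_FP32:
                return TOSA_REF_TYPE_FP32;
            case DType_INT32:
                return TOSA_REF_TYPE_INT32;
        }
        return TOSA_REF_TYPE_UNKNOWN;
    }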
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject dce6cebbeb6c45625c4ef8fafb5a7775319101c
+Subproject cfcb20d08c4c409bbcd2d2dde6ca5ecdac29945
diff --git a/verif/frameworks/tosa_verif_framework_compiler_runner.py b/verif/frameworks/tosa_verif_framework_compiler_runner.py
index 0d98c17..28e4369 100755
--- a/verif/frameworks/tosa_verif_framework_compiler_runner.py
+++ b/verif/frameworks/tosa_verif_framework_compiler_runner.py
@@ -57,6 +57,13 @@ def parse_args():
help="Reference model base directory",
)
parser.add_argument(
+ "-p",
+ "--precise-mode",
+ dest="precise_mode",
+ action="store_true",
+ help="run in precise mode (FP64)",
+ )
+ parser.add_argument(
"-v", "--verbose", dest="verbose", action="count", help="Verbose run"
)
parser.add_argument(
@@ -552,6 +559,9 @@ def run_test(args, test, framework):
if args.debug_ref_model:
ref_model_cmd.extend(["-D ALL", "-l high"])
+ if args.precise_mode:
+ ref_model_cmd.extend(["--precise_mode=1"])
+
if args.valgrind:
ref_model_cmd = [
"valgrind",
@@ -594,7 +604,11 @@ def run_test(args, test, framework):
)
return (TestResult.REF_MODEL_RUNTIME_ERROR, 0.0, e)
- if tf_result.dtype == np.float16:
+ if args.precise_mode and (
+ tf_result.dtype == np.float16 or tf_result.dtype == np.float32
+ ):
+ tf_result = tf_result.astype(np.float64)
+ elif tf_result.dtype == np.float16:
tf_result = tf_result.astype(np.float32)
elif (
tf_result.dtype == np.uint8
diff --git a/verif/runner/tosa_refmodel_sut_run.py b/verif/runner/tosa_refmodel_sut_run.py
index 95f6e7b..df5c0db 100644
--- a/verif/runner/tosa_refmodel_sut_run.py
+++ b/verif/runner/tosa_refmodel_sut_run.py
@@ -45,6 +45,9 @@ class TosaSUTRunner(TosaTestRunner):
if args.ref_intermediates:
cmd.extend(["--dump_intermediates", str(args.ref_intermediates)])
+ if args.precise_mode:
+ cmd.extend(["--precise_mode=1"])
+
# Run command and interpret tosa graph result via process return codes
graphMessage = None
try:
diff --git a/verif/runner/tosa_verif_run_tests.py b/verif/runner/tosa_verif_run_tests.py
index 6b5d77e..814c864 100644
--- a/verif/runner/tosa_verif_run_tests.py
+++ b/verif/runner/tosa_verif_run_tests.py
@@ -147,6 +147,13 @@ def parseArgs(argv):
help="A TOSA level defines operator parameter ranges that an implementation shall support."
"Config tosa_level for running the reference model only. Default is EIGHTK",
)
+ parser.add_argument(
+ "-p",
+ "--precise-mode",
+ dest="precise_mode",
+ action="store_true",
+ help="Run the reference model in precise mode (FP64)",
+ )
args = parser.parse_args(argv)
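Taken together, the runner-side plumbing is deliberately thin: in both tosa_verif_run_tests.py and tosa_refmodel_sut_run.py the new -p/--precise-mode switch simply appends --precise_mode=1 to the reference model command line (an invocation would look like tosa_verif_run_tests.py -p <usual test arguments>), and tosa_verif_framework_compiler_runner.py additionally promotes float16/float32 TensorFlow results to float64 before comparison, so that the FP64 reference outputs are checked at matching precision.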