path: root/reference_model
author     James Ward <james.ward@arm.com>  2022-08-12 20:48:56 +0100
committer  James Ward <james.ward@arm.com>  2022-10-11 11:56:02 +0100
commit     8b39043c70332e1e4c95ee6a9616aec40dd3baf1 (patch)
tree       fea519246b698eb944b9d58537fc90bc30481d11 /reference_model
parent     ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (diff)
download   reference_model-8b39043c70332e1e4c95ee6a9616aec40dd3baf1.tar.gz
Reference model changes for fp16 support
Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward <james.ward@arm.com>
Diffstat (limited to 'reference_model')
-rw-r--r--  reference_model/CMakeLists.txt               |  10
-rw-r--r--  reference_model/src/graph_node.h             |  19
-rw-r--r--  reference_model/src/ops/activation_funcs.cc  |   8
-rw-r--r--  reference_model/src/ops/comparison.cc        |   8
-rw-r--r--  reference_model/src/ops/data_layout.cc       |  10
-rw-r--r--  reference_model/src/ops/data_nodes.cc        |   3
-rw-r--r--  reference_model/src/ops/ewise_binary.cc      |  17
-rw-r--r--  reference_model/src/ops/ewise_ternary.cc     |   3
-rw-r--r--  reference_model/src/ops/ewise_unary.cc       |  18
-rw-r--r--  reference_model/src/ops/image.cc             |   7
-rw-r--r--  reference_model/src/ops/op_factory.cc        | 112
-rw-r--r--  reference_model/src/ops/op_factory.h         |  39
-rw-r--r--  reference_model/src/ops/reduction.cc         |   6
-rw-r--r--  reference_model/src/ops/scatter_gather.cc    |   4
-rw-r--r--  reference_model/src/ops/template_types.h     |  41
-rw-r--r--  reference_model/src/ops/tensor_ops.cc        | 315
-rw-r--r--  reference_model/src/ops/tensor_ops.h         |  81
-rw-r--r--  reference_model/src/ops/type_conversion.cc   |  33
-rw-r--r--  reference_model/src/ops/type_conversion.h    |  38
-rw-r--r--  reference_model/src/subgraph_traverser.cc    |  12
-rw-r--r--  reference_model/src/tensor.cc                |  46
-rw-r--r--  reference_model/src/tensor.h                 |   3
22 files changed, 597 insertions(+), 236 deletions(-)
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index a790968..04b0db5 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -43,7 +43,7 @@ else()
set(SERIALIZATION_DIR "../thirdparty/serialization_lib/")
endif()
-# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory.
+# If the Flatbuffers, Eigen or Half path isn't specified, set it to the thirdparty directory.
if(NOT FLATBUFFERS_DIR)
set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/")
endif()
@@ -52,6 +52,10 @@ if(NOT EIGEN_DIR)
set(EIGEN_DIR "../thirdparty/eigen/")
endif()
+if(NOT HALF_DIR)
+ set(HALF_DIR "../thirdparty/serialization_lib/third_party/half")
+endif()
+
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Common sources required for TOSA Reference Model library, executable and unit tests
@@ -92,6 +96,7 @@ target_include_directories(tosa_reference_model_lib
${EIGEN_DIR}
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
+ ${HALF_DIR}/include
)
target_link_libraries(tosa_reference_model_lib
@@ -130,6 +135,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
${EIGEN_DIR}
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
+ ${HALF_DIR}/include
)
target_link_libraries(tosa_reference_model
@@ -171,6 +177,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
${EIGEN_DIR}
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
+ ${HALF_DIR}/include
${DOCTEST_DIR}
)
@@ -203,6 +210,7 @@ if(BUILD_MODEL_RUNNER_SAMPLE)
${EIGEN_DIR}
${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
+ ${HALF_DIR}/include
)
target_link_libraries(model_runner_sample
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index adf81b2..874f1d8 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,6 +24,9 @@
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, RANK, DTYPE, ACCUM_DTYPE) \
+ template class TosaReference::OP<RANK, DType_##DTYPE, DType_##ACCUM_DTYPE>;
+
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
@@ -35,8 +38,14 @@
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
+#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
+ template class TosaReference::OP<DType_##DTYPE, DType_##ACCUM_DTYPE>;
+
#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
+#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, DTYPE1, DTYPE2, ACCUM_DTYPE) \
+ template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>;
+
#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, OP_TYPE) \
template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, OP_TYPE>;
@@ -57,6 +66,14 @@
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
+#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 1, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 2, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 3, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 4, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 5, DTYPE, ACCUM_DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 6, DTYPE, ACCUM_DTYPE)
+
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
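To make the new token pasting concrete, here is a minimal, self-contained sketch of how the _ONE_ACCUM macros expand. OpDemo is a placeholder class, not a reference-model op; the macro body mirrors DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM in the hunk above.

#include <iostream>

enum DType
{
    DType_FP16,
    DType_FLOAT
};

namespace TosaReference
{
// Stand-in for an op class that now carries an accumulator dtype parameter.
template <DType In, DType Accum>
class OpDemo
{
public:
    static constexpr DType in    = In;
    static constexpr DType accum = Accum;
};
}    // namespace TosaReference

// Same shape as the macro in the hunk above.
#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
    template class TosaReference::OP<DType_##DTYPE, DType_##ACCUM_DTYPE>;

// Expands to: template class TosaReference::OpDemo<DType_FP16, DType_FLOAT>;
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpDemo, FP16, FLOAT)

int main()
{
    std::cout << (TosaReference::OpDemo<DType_FP16, DType_FLOAT>::accum == DType_FLOAT) << std::endl;
    return 0;
}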
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index c344bcb..1c0c23a 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ int OpClamp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
{
InEigenType min = (InEigenType)attribute->min_fp();
@@ -57,6 +58,7 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
break;
@@ -72,6 +74,7 @@ int OpTanh<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
break;
@@ -83,10 +86,13 @@ int OpTanh<Rank, Dtype>::register_fcn()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index ab89e24..5930c1a 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ int OpEqual<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
@@ -43,6 +44,7 @@ int OpGreater<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
@@ -59,6 +61,7 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
@@ -71,11 +74,14 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 7840450..1ed0be2 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -190,6 +190,7 @@ int OpPad<Rank, Dtype>::eval()
case DType_INT32:
pad_value = (InEigenType)attribute->pad_const_int();
break;
+ case DType_FP16:
case DType_FLOAT:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
@@ -637,42 +638,49 @@ int OpTranspose<Rank, Dtype>::eval()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 30c9511..4ff08be 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -89,6 +89,7 @@ int OpIdentity<Rank, Dtype>::eval()
// template explicit instantiation
// note OpConst is not templated
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index a5e1a20..917d56e 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -142,6 +142,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
break;
@@ -369,6 +370,7 @@ int OpMaximum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
@@ -385,6 +387,7 @@ int OpMinimum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
@@ -403,6 +406,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
switch (InDtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
break;
@@ -452,6 +456,7 @@ int OpPow<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
break;
@@ -476,6 +481,7 @@ int OpSub<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
break;
@@ -574,6 +580,7 @@ int OpTable<Rank, InDtype>::eval()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
@@ -609,24 +616,32 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
+// Instantiation of nodes for comparison operators opEqual, opGreater
+// and opGreaterEqual
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 72fe5a0..da046a7 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -107,6 +107,7 @@ int OpSelect<0, Dtype>::eval()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 8ef1e3c..52f5aff 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -79,6 +79,7 @@ int OpAbs<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FLOAT:
+ case DType_FP16:
case DType_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
break;
@@ -111,6 +112,7 @@ int OpCeil<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
break;
@@ -158,6 +160,7 @@ int OpExp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
break;
@@ -173,6 +176,7 @@ int OpFloor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
break;
@@ -188,6 +192,7 @@ int OpLog<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
break;
@@ -239,6 +244,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType {
InEigenType result = -(a);
@@ -290,6 +296,7 @@ int OpReciprocal<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
break;
@@ -305,6 +312,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
+ case DType_FP16:
case DType_FLOAT:
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
break;
@@ -316,6 +324,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
@@ -323,23 +332,30 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index a021226..891261b 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -63,7 +63,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT && OutDtype != DType_FP16)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -71,7 +71,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT && OutDtype != DType_FP16)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -224,4 +224,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FLOAT, FLOAT, float);
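Per the DEF_INSTANTIATE_THREE_TYPE definition in graph_node.h, the new FP16 line above expands to the single explicit instantiation below; the third template argument (the resize_t compute type) stays float, consistent with the full-precision FP16 Eigen type introduced in template_types.h later in this patch.

template class TosaReference::OpResize<DType_FP16, DType_FP16, float>;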
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index b6a2e15..fd73eb5 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -48,71 +48,91 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
{
// tensor_ops
case Op_ARGMAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
break;
case Op_AVG_POOL2D:
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8);
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
break;
case Op_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48);
break;
case Op_CONV3D:
- DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48);
break;
case Op_FULLY_CONNECTED:
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48);
break;
case Op_MATMUL:
- DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpMatMul, INT8);
- DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48);
break;
// activation_funcs
case Op_CLAMP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_SIGMOID:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
break;
case Op_TANH:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
break;
// ewise_binary
case Op_ADD:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
break;
@@ -159,23 +179,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
break;
case Op_MAXIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
break;
case Op_MINIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
break;
case Op_MUL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
break;
case Op_POW:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
break;
case Op_SUB:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
@@ -186,6 +211,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// ewise_unary
case Op_ABS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
@@ -195,38 +221,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
break;
case Op_CEIL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
break;
case Op_EXP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
break;
case Op_FLOOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
break;
case Op_LOG:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
break;
case Op_NEGATE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
break;
case Op_RECIPROCAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
break;
case Op_RSQRT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
break;
// ewise_ternary
case Op_SELECT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
@@ -236,14 +270,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// comparison
case Op_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
break;
case Op_GREATER:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
break;
case Op_GREATER_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
break;
@@ -256,27 +293,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
break;
case Op_REDUCE_MAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
break;
case Op_REDUCE_MIN:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
break;
case Op_REDUCE_PRODUCT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
break;
case Op_REDUCE_SUM:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
// data layout
case Op_CONCAT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
@@ -284,6 +326,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
break;
case Op_PAD:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
@@ -291,6 +334,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
break;
case Op_RESHAPE:
+ DEF_FACTORY_RESHAPE(OpReshape, FP16);
DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
@@ -298,6 +342,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
break;
case Op_REVERSE:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -305,6 +350,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
break;
case Op_SLICE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -312,6 +358,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
break;
case Op_TILE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -320,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
@@ -331,12 +379,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, INT8);
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP16);
DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
break;
@@ -346,6 +396,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT8, INT8);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT);
break;
@@ -353,6 +404,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONST:
return new OpConst(sgt, id);
case Op_IDENTITY:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
@@ -368,15 +420,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 341d7dc..25dfc6e 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -62,18 +62,55 @@
return new OP<DType_##DTYPE>(sgt, attribute, id); \
}
+#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
+ if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ { \
+ return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
{ \
return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
+ && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ } \
+
+// Statement-expression to evaluate accumulate attribute in-place
+#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
+ ({ \
+ tosa::DType accumDType = tosa::DType_UNKNOWN; \
+ if (auto p = dynamic_cast<tosa::Tosa##ATTRIBUTE_NAME##Attribute*>(attribute)) \
+ { \
+ auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \
+ ASSERT_MEM(attr); \
+ accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \
+ } \
+ else \
+ { \
+ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
+ "of this attribute is required in order to determine the accumulate type."); \
+ } \
+ accumDType; \
+ }) \
+
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
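Putting the two new pieces together, a factory line such as DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT) from op_factory.cc expands roughly as the sketch below (not verbatim preprocessor output). Note that ACCUM_FROM_ATTRIBUTE uses a GCC/Clang statement-expression, so this header requires a compiler that supports that extension.

// Sketch of the expansion for one dispatch case:
if (inputDType == DType_FP16 && ACCUM_FROM_ATTRIBUTE(Pool) == DType_FLOAT)
{
    // ACCUM_FROM_ATTRIBUTE(Pool) dynamic_casts 'attribute' to
    // TosaPoolAttribute and yields its accum_dtype() as a tosa::DType.
    return new OpAvgPool2d<DType_FP16, DType_FLOAT>(sgt, attribute, id);
}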
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 18fac44..03ee660 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -158,17 +158,21 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index faf7db9..25174bd 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -226,9 +226,11 @@ int OpScatter<Dtype>::eval()
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 2bc7e04..9511c31 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
#include "tosa_generated.h"
#include <Eigen/CXX11/Tensor>
+#include "half.hpp"
using namespace tosa;
@@ -69,6 +70,12 @@ struct GetEigenType<DType_FLOAT>
using type = float;
};
template <>
+struct GetEigenType<DType_FP16>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
struct GetEigenType<DType_INT32>
{
using type = int32_t;
@@ -109,6 +116,28 @@ struct GetEigenType<DType_INT16>
using type = int32_t;
};
+/* Get Accumulate Eigen Type:
+Same behaviour as GetEigenType for all DTypes except the
+single specialised case of DType_FP16. */
+template <DType Dtype>
+struct GetAccEigenType;
+template <>
+struct GetAccEigenType<DType_FP16>
+{
+ using type = half_float::half;
+};
+template <DType Dtype>
+struct GetAccEigenType
+{
+ using type = typename GetEigenType<Dtype>::type;
+};
+
+template <DType Dtype>
+struct GetHalfEigenType
+{
+ using type = half_float::half;
+};
+
// Meta function to get number of bits
template <DType T>
struct GetNumBits
@@ -155,6 +184,11 @@ struct GetNumBits<DType_INT48>
{
static constexpr int32_t value = 48;
};
+template <>
+struct GetNumBits<DType_FP16>
+{
+ static constexpr int32_t value = 16;
+};
// Meta function to get quantized min/max in compile time
template <DType T>
@@ -262,6 +296,11 @@ struct GetAccDType<DType_INT16, DType_INT16>
static constexpr DType value = DType_INT48;
};
template <>
+struct GetAccDType<DType_FP16, DType_FP16>
+{
+ static constexpr DType value = DType_FP16;
+};
+template <>
struct GetAccDType<DType_FLOAT, DType_FLOAT>
{
static constexpr DType value = DType_FLOAT;
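A minimal, self-contained sketch of the accumulator-type selection above (the two-value DType enum is for illustration only; half.hpp is the vendored half-precision library added to the include paths in CMakeLists.txt):

#include <type_traits>

#include "half.hpp"

enum DType
{
    DType_FP16,
    DType_FLOAT
};

// Storage type: FP16 tensors are held at full precision (see NOTE above).
template <DType D>
struct GetEigenType
{
    using type = float;
};

// Accumulator type: defaults to the storage type...
template <DType D>
struct GetAccEigenType
{
    using type = typename GetEigenType<D>::type;
};

// ...except FP16, which maps to the native half type.
template <>
struct GetAccEigenType<DType_FP16>
{
    using type = half_float::half;
};

static_assert(std::is_same<GetAccEigenType<DType_FLOAT>::type, float>::value, "FLOAT accumulates in float");
static_assert(std::is_same<GetAccEigenType<DType_FP16>::type, half_float::half>::value, "FP16 accumulates in half");

int main()
{
    return 0;
}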
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 2cd94bb..c617dda 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
#include "tensor_ops.h"
#include "quant_util.h"
#include "template_types.h"
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -329,8 +330,8 @@ int OpArgMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
-OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_AVG_POOL2D, id_)
@@ -341,15 +342,15 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype>
-OpAvgPool2d<Dtype>::~OpAvgPool2d()
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -385,8 +386,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <DType Dtype, DType AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
@@ -414,8 +415,8 @@ ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::eval()
{
int in_batch = this->in->getShape()[0];
int in_height = this->in->getShape()[1];
@@ -439,11 +440,13 @@ int OpAvgPool2d<Dtype>::eval()
int stride_h = this->attribute->stride()[0];
int stride_w = this->attribute->stride()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
- "stride=[%d,%d], pad=[%d,%d,%d,%d]",
+ "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
- kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right);
+ kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<Eigen::Index, 2> im2col_input_dims;
im2col_input_dims[0] = kernel_h * kernel_w;
@@ -509,8 +512,7 @@ int OpAvgPool2d<Dtype>::eval()
.contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
-
- if (Dtype != DType_FLOAT)
+ if (Dtype != DType_FLOAT && Dtype != DType_FP16)
{
try
{
@@ -531,14 +533,15 @@ int OpAvgPool2d<Dtype>::eval()
}
else
{
+ // Case for floating-point types (FLOAT and FP16)
this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
@@ -549,15 +552,15 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -574,12 +577,12 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpConv2d: Output data type not supported for this configuration of operator");
+ "OpConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -593,8 +596,8 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -630,12 +633,14 @@ int OpConv2d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
- "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
+ "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right);
+ pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
// GEMM-conv2d, left matrix is input, right matrix is weight
Eigen::array<Eigen::Index, 2> im2col_input_dims;
@@ -695,33 +700,33 @@ int OpConv2d<InDtype, WeightDtype>::eval()
// transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
ETensor2<WeightEigenType> im2col_weight =
- weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
+ weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
// don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
// and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
- ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);
+ ETensor2<OutEigenType> bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast<OutEigenType>();
// output matrix is [N * H * W, C]
- ETensor2<AccEigenType> contracted_result =
- im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);
+ ETensor2<OutEigenType> contracted_result =
+ (im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims)).template cast<OutEigenType>();
// adding bias
- ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();
+ ETensor2<OutEigenType> biased_output = contracted_result + bias_2d;
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
@@ -732,15 +737,15 @@ OpConv3d<InDtype, WeightDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpConv3d<InDtype, WeightDtype>::~OpConv3d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -757,12 +762,12 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpConv3d: Output data type not supported for this configuration of operator");
+ "OpConv3d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -776,8 +781,8 @@ int OpConv3d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpConv3d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_depth = this->input->getShape()[1];
@@ -821,13 +826,15 @@ int OpConv3d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[1];
int dilation_w = this->attribute->dilation()[2];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(
OP,
"perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
- "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
+ "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
- dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);
+ dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
pad[0] = std::make_pair(0, 0);
@@ -860,7 +867,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
// 2. direct convolution
- AccEigenType acc = 0;
+ AccEigenType acc(0.0);
int d_idx, h_idx, w_idx;
for (int ob = 0; ob < out_batch; ob++)
@@ -874,7 +881,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
for (int oc = 0; oc < out_channels; oc++)
{
// Initialize accumulator with bias value
- acc = this->output->getTensor()(ob, od, oh, ow, oc);
+ acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc);
for (int fd = 0; fd < f_depth; fd++)
{
d_idx = od * stride_d + fd * dilation_d;
@@ -892,7 +899,7 @@ int OpConv3d<InDtype, WeightDtype>::eval()
}
}
}
- this->output->getTensor()(ob, od, oh, ow, oc) = acc;
+ this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc;
}
}
}
@@ -901,15 +908,15 @@ int OpConv3d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
@@ -920,15 +927,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(SubgraphTraverser* sg
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -945,12 +952,12 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
+ "OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
std::string msg;
if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -964,8 +971,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1002,12 +1009,14 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP,
"perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right);
+ pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
pad[0] = std::make_pair(0, 0);
@@ -1061,9 +1070,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
{
for (int fw = 0; fw < f_width; fw++)
{
+ // Perform the multiplication in AccEigenType, then cast to OutEigenType
this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
- ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
- (AccEigenType)weight_val(fh, fw, ic, cm));
+ (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
+ (AccEigenType)weight_val(fh, fw, ic, cm));
}
}
}
@@ -1074,15 +1084,15 @@ int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
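
OpDepthwiseConv2d narrows per product instead of keeping a running accumulator: each multiply happens in the accumulator type, is cast to the output type, and summed there. A scalar sketch with illustrative FP16 types, mirroring the expression in the hunk above:

#include "half.hpp"

int main()
{
    using AccT = half_float::half;   // accumulator Eigen type when acc_t is FP16
    using OutT = float;              // FP16 output storage type
    float in = 0.1f, w = 3.0f;       // FP16 tensor values, held as float
    OutT sum = 0.0f;                 // output element, pre-initialised with the bias
    // The product is formed (and rounded) at half precision, then widened.
    sum += (OutT)((AccT)in * (AccT)w);
    return 0;
}
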
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
@@ -1093,15 +1103,15 @@ OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1128,9 +1138,9 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpFullyConnected: Output data type not supported for this configuration of operator");
+ "OpFullyConnected: Output data type not supported for this configuration of operator");
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
@@ -1138,8 +1148,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1163,19 +1173,19 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
}
this->output->getTensor() =
- input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
- this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+ input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
+ this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
-template <DType Dtype>
-OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_MATMUL, id_)
@@ -1186,15 +1196,15 @@ OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype>
-OpMatMul<Dtype>::~OpMatMul()
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
-int OpMatMul<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1205,11 +1215,11 @@ int OpMatMul<Dtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpMatMul: Output data type not supported for this configuration of operator");
+ "OpMatMul: Output data type not supported for this configuration of operator");
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
ASSERT_MEM(a && b && output);
@@ -1255,8 +1265,8 @@ int OpMatMul<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
-int OpMatMul<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1289,22 +1299,22 @@ int OpMatMul<Dtype>::eval()
TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
TAccRank2 output_rank2_val =
a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
- TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+ TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
if (i == 0)
{
this->output->getTensor() = output_rank3_val;
}
else
{
- TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0);
+ TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0);
this->output->getTensor() = temp;
}
}
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
@@ -1442,8 +1452,8 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
@@ -1454,15 +1464,15 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sg
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1473,12 +1483,12 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
}
ERROR_IF(outputs[0]->getDtype() != AccDtype,
- "OpTransposeConv2d: Output data type not supported for this configuration of operator");
+ "OpTransposeConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
- output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
if (attribute->out_pad().size() != 4)
{
@@ -1556,8 +1566,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1584,6 +1594,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
int stride_h = this->attribute->stride()[0];
int stride_w = this->attribute->stride()[1];
+ tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
in_channels);
@@ -1594,10 +1606,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
DEBUG_INFO(OP,
"perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
- out_pad_bottom, out_pad_left, out_pad_right);
+ out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
@@ -1645,8 +1657,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
{
this->output->getTensor()(ob, out_y, out_x, oc) +=
- ((AccEigenType)input_val(ob, ih, iw, ic) *
- (AccEigenType)weight_val(oc, fh, fw, ic));
+ (OutEigenType) ((AccEigenType)input_val(ob, ih, iw, ic) *
+ (AccEigenType)weight_val(oc, fh, fw, ic));
}
}
}
@@ -1658,51 +1670,68 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
if (AccDtype == DType_INT48)
{
- this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
- this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
}
return GraphNode::eval();
}
// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT8)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
-
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT8);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
-
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
+
+// [in_t, weight_t, acc_t]
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
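
The move from DEF_INSTANTIATE_TWO_TYPE to DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM makes the accumulator dtype an explicit third template argument instead of a value derived from the input/weight pair via GetAccDType; that is what lets FP16 inputs instantiate both an FP16- and a FLOAT-accumulating variant. A minimal sketch of the idea, with an illustrative enum and names (not the model's actual macros):

enum class DT { FP16, FLOAT, INT4, INT8, INT16, INT32, INT48 };

template <DT In, DT Weight, DT Acc>
struct ConvSketch
{
    // In the real header the accumulator Eigen type comes from a separate
    // GetAccEigenType trait (see tensor_ops.h below); the point here is only
    // that Acc is a free parameter rather than derived from In and Weight.
};

// One explicit instantiation per supported [in_t, weight_t, acc_t] triple:
template struct ConvSketch<DT::FP16,  DT::FP16, DT::FP16>;
template struct ConvSketch<DT::FP16,  DT::FP16, DT::FLOAT>;
template struct ConvSketch<DT::INT8,  DT::INT4, DT::INT32>;
template struct ConvSketch<DT::INT16, DT::INT8, DT::INT48>;
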
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 24eadeb..fd6dd25 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* output;
};
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
class OpAvgPool2d : public GraphNode
{
public:
@@ -55,9 +55,8 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
using InEigenType = typename GetEigenType<Dtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<Dtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TOut = Eigen::Tensor<OutEigenType, 4>;
@@ -75,7 +74,7 @@ protected:
ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
class OpConv2d : public GraphNode
{
public:
@@ -85,15 +84,14 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
- using TBias = Eigen::Tensor<AccEigenType, 1>;
- using TAcc = Eigen::Tensor<AccEigenType, 4>;
+ using TBias = Eigen::Tensor<OutEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -102,11 +100,11 @@ protected:
TosaReference::TensorTemplate<TIn>* input;
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
class OpConv3d : public GraphNode
{
public:
@@ -116,15 +114,14 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 5>;
using TWeight = Eigen::Tensor<WeightEigenType, 5>;
- using TBias = Eigen::Tensor<AccEigenType, 1>;
- using TAcc = Eigen::Tensor<AccEigenType, 5>;
+ using TBias = Eigen::Tensor<OutEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 5>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -133,11 +130,11 @@ protected:
TosaReference::TensorTemplate<TIn>* input;
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -147,15 +144,14 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
- using TBias = Eigen::Tensor<AccEigenType, 1>;
- using TAcc = Eigen::Tensor<AccEigenType, 4>;
+ using TBias = Eigen::Tensor<OutEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -164,11 +160,11 @@ protected:
TosaReference::TensorTemplate<TIn>* input;
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
class OpFullyConnected : public GraphNode
{
public:
@@ -178,14 +174,14 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 2>;
using TWeight = Eigen::Tensor<WeightEigenType, 2>;
- using TBias = Eigen::Tensor<AccEigenType, 1>;
- using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ using TBias = Eigen::Tensor<OutEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 2>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -194,12 +190,12 @@ protected:
TosaReference::TensorTemplate<TIn>* input;
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
tosa::TosaFullyConnectedAttribute* attribute;
};
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
class OpMatMul : public GraphNode
{
public:
@@ -209,11 +205,11 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
using InEigenType = typename GetEigenType<Dtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 3>;
- using TAcc = Eigen::Tensor<AccEigenType, 3>;
+ using TOut = Eigen::Tensor<OutEigenType, 3>;
using TInRank2 = Eigen::Tensor<InEigenType, 2>;
using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
@@ -222,7 +218,7 @@ public:
protected:
TosaReference::TensorTemplate<TIn>* a;
TosaReference::TensorTemplate<TIn>* b;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
int64_t N;
int64_t H;
int64_t W;
@@ -252,7 +248,7 @@ protected:
tosa::TosaPoolAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
class OpTransposeConv2d : public GraphNode
{
public:
@@ -262,15 +258,14 @@ public:
virtual int checkTensorAttributes() final;
virtual int eval() final;
- static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<AccDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
- using TBias = Eigen::Tensor<AccEigenType, 1>;
- using TAcc = Eigen::Tensor<AccEigenType, 4>;
+ using TBias = Eigen::Tensor<OutEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -279,7 +274,7 @@ protected:
TosaReference::TensorTemplate<TIn>* input;
TosaReference::TensorTemplate<TWeight>* weight;
TosaReference::TensorTemplate<TBias>* bias;
- TosaReference::TensorTemplate<TAcc>* output;
+ TosaReference::TensorTemplate<TOut>* output;
TosaTransposeConvAttribute* attribute;
};
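
The recurring "Note: different from GetEigenType" comment is the crux of the header change: FP16 tensor data is stored widened to float, but when acc_t is FP16 the intermediate arithmetic must round at half precision, so the accumulator type resolves to a genuine half. A hedged sketch of the distinction (trait names follow the header; the bodies here are assumptions):

#include "half.hpp"   // half_float::half, from the bundled Half library

enum class DT { FP16, FLOAT };

// Storage/output Eigen type: FP16 tensors are carried as float end to end.
template <DT D> struct GetEigenTypeSketch    { using type = float; };

// Accumulator Eigen type: same as storage, except for FP16, where arithmetic
// must round to binary16 between multiply-accumulates.
template <DT D> struct GetAccEigenTypeSketch { using type = float; };
template <>     struct GetAccEigenTypeSketch<DT::FP16>
{
    using type = half_float::half;
};
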
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 52de2e4..50e710a 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
#include "quant_util.h"
#include "template_types.h"
#include <cmath>
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -287,6 +288,30 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
}
template <DType InDtype>
+CastHelper<InDtype, DType_FP16>::CastHelper()
+{
+ fcn = [](InEigenType in) -> float {
+ half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in); // Cast to half_float::half (rounds to binary16)
+ return half_float::half_cast<float, half_float::half>(out); // Cast to float (underlying FP16 EigenType)
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FP16, OutDtype>::CastHelper()
+{
+ // Assuming InEigenType = float.
+ fcn = [](float in) -> OutEigenType {
+ // Perform initial rounding in half precision, then cast back to float
+ half_float::half h = half_float::half_cast<half_float::half, float>(in);
+ h = std::round(h);
+ OutEigenType out = half_float::half_cast<float, half_float::half>(h);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+}
+
+template <DType InDtype>
CastHelper<InDtype, DType_FLOAT>::CastHelper()
{
fcn = [](InEigenType in) -> float {
@@ -313,15 +338,21 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
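
Both new CastHelper constructors lean on half_float::half_cast, which converts and applies binary16 rounding in one step. A standalone usage sketch of the same calls (values illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include "half.hpp"

int main()
{
    // int32 -> fp16: half_cast rounds to the nearest representable binary16
    // value, then a second cast widens back to float, the FP16 storage type.
    int32_t i = 33000;   // nearest fp16 value is 32992
    half_float::half h = half_float::half_cast<half_float::half, int32_t>(i);
    float as_fp16 = half_float::half_cast<float, half_float::half>(h);
    std::printf("%d -> %.1f\n", i, as_fp16);

    // fp16 -> int8: round while still at half precision, widen, then clamp
    // to the destination range, mirroring the OutMin/OutMax clamp above.
    float in = 126.7f;   // an FP16 tensor value, held as float
    half_float::half r = half_float::half_cast<half_float::half, float>(in);
    r = std::round(r);
    float rounded = half_float::half_cast<float, half_float::half>(r);
    int32_t out = (int32_t)std::min(127.0f, std::max(-128.0f, rounded));
    std::printf("%.1f -> %d\n", in, out);
    return 0;
}
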
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 53470d1..5f197cf 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -101,6 +101,42 @@ private:
};
template <DType InDtype>
+class CastHelper<InDtype, DType_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FP16, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FP16>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
class CastHelper<InDtype, DType_FLOAT>
{
public:
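
With one partial specialisation per FP16 endpoint, any CAST touching FP16 now selects a dedicated helper at compile time. A small sketch of the dispatch (enum and names illustrative):

enum DT { DT_INT8, DT_FP16, DT_FLOAT };

template <DT In, DT Out> struct CastSketch { static const int which = 0; };
template <DT In>  struct CastSketch<In, DT_FP16>  { static const int which = 1; };
template <DT Out> struct CastSketch<DT_FP16, Out> { static const int which = 2; };

static_assert(CastSketch<DT_INT8, DT_FP16>::which == 1, "to-fp16 helper");
static_assert(CastSketch<DT_FP16, DT_INT8>::which == 2, "from-fp16 helper");
// CastSketch<DT_FP16, DT_FP16> would be ambiguous between the two partial
// specialisations; no FP16 -> FP16 cast is instantiated, so it is never hit.
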
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index d0cc6cf..386f0e5 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -385,6 +385,14 @@ int SubgraphTraverser::allocateTensor()
tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
}
break;
+ case DType_FP16:
+ {
+ // Interpret f16 data as float
+ std::vector<float> f16_data;
+ TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
+ tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+ }
+ break;
case DType_FLOAT:
{
std::vector<float> fp32_data;
@@ -693,7 +701,7 @@ int SubgraphTraverser::validateGraph()
DType dtype = currTensor->getDtype();
// Float-point disallowed
- if (dtype == DType_FLOAT)
+ if (dtype == DType_FLOAT || dtype == DType_FP16)
{
WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
"disabled, but %s tensor %s found\n",
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 7cbeb13..cbe12a9 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
#include "tensor.h"
#include "arith_util.h"
+#include "half.hpp"
using namespace TosaReference;
using namespace Eigen;
@@ -84,6 +85,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
{
uint32_t elements = getElementCount();
float* fdatabuf = nullptr;
+ half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
@@ -97,6 +99,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
break;
+ case DType_FP16:
+ f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+ ASSERT_MEM(f16databuf);
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf);
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4:
@@ -146,9 +156,17 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
switch (getDtype())
{
+ case DType_FP16:
+ // Convert from fp16 to fp32
+ for (uint32_t i=0; i < elements; i++) {
+ fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
+ }
+ // Fall through to DType_FLOAT case
case DType_FLOAT:
if (setTensorValueFloat(elements, fdatabuf))
{
+ if (f16databuf)
+ free(f16databuf);
free(fdatabuf);
return 1;
}
@@ -187,6 +205,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
if (fdatabuf)
free(fdatabuf);
+ if (f16databuf)
+ free(f16databuf);
if (i32databuf)
free(i32databuf);
if (i64databuf)
@@ -200,11 +220,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int TosaReference::Tensor::writeToNpyFile(const char* filename) const
{
float* fdatabuf = nullptr;
+ half_float::half* f16databuf = nullptr;
int32_t* i32databuf = nullptr;
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
- int elements = getElementCount();
+ uint32_t elements = getElementCount();
switch (getDtype())
{
@@ -222,6 +243,27 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(fdatabuf);
break;
+ case DType_FP16:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+ f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+ ASSERT_MEM(f16databuf);
+
+ if (getTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ free(f16databuf);
+ return 1;
+ }
+ // Convert fp32 to fp16
+ for (uint32_t i=0; i < elements; i++) {
+ f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]);
+ }
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf);
+
+ free(fdatabuf);
+ free(f16databuf);
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4:
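
A consequence of widening on read and narrowing on write is that the .npy boundary quantises to binary16: any value not exactly representable in fp16 comes back changed. A tiny illustration:

#include <cstdio>
#include "half.hpp"

int main()
{
    float original = 0.1f;   // not exactly representable in binary16
    half_float::half narrowed = half_float::half_cast<half_float::half, float>(original);
    float widened = half_float::half_cast<float, half_float::half>(narrowed);
    std::printf("%.9f -> %.9f\n", original, widened);   // 0.100000001 -> 0.099975586
    return 0;
}
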
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 6b7d5f1..78a210e 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -643,6 +643,7 @@ public:
switch (tensorDtype_)
{
case DType_FLOAT:
+ case DType_FP16:
switch (rank)
{
case 0: