From 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 Mon Sep 17 00:00:00 2001
From: James Ward
Date: Fri, 12 Aug 2022 20:48:56 +0100
Subject: Reference model changes for fp16 support

Change-Id: I72f21fcfa153046274969d327313e3349981dbe6
Signed-off-by: James Ward
---
 reference_model/CMakeLists.txt              |  10 +-
 reference_model/src/graph_node.h            |  19 +-
 reference_model/src/ops/activation_funcs.cc |   8 +-
 reference_model/src/ops/comparison.cc       |   8 +-
 reference_model/src/ops/data_layout.cc      |  10 +-
 reference_model/src/ops/data_nodes.cc       |   3 +-
 reference_model/src/ops/ewise_binary.cc     |  17 +-
 reference_model/src/ops/ewise_ternary.cc    |   3 +-
 reference_model/src/ops/ewise_unary.cc      |  18 +-
 reference_model/src/ops/image.cc            |   7 +-
 reference_model/src/ops/op_factory.cc       | 112 +++++++---
 reference_model/src/ops/op_factory.h        |  39 +++-
 reference_model/src/ops/reduction.cc        |   6 +-
 reference_model/src/ops/scatter_gather.cc   |   4 +-
 reference_model/src/ops/template_types.h    |  41 +++-
 reference_model/src/ops/tensor_ops.cc       | 315 +++++++++++++++-------------
 reference_model/src/ops/tensor_ops.h        |  81 ++++---
 reference_model/src/ops/type_conversion.cc  |  33 ++-
 reference_model/src/ops/type_conversion.h   |  38 +++-
 reference_model/src/subgraph_traverser.cc   |  12 +-
 reference_model/src/tensor.cc               |  46 +++-
 reference_model/src/tensor.h                |   3 +-
 22 files changed, 597 insertions(+), 236 deletions(-)

(limited to 'reference_model')

diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index a790968..04b0db5 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -43,7 +43,7 @@ else()
     set(SERIALIZATION_DIR "../thirdparty/serialization_lib/")
 endif()
 
-# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory.
+# If the Flatbuffers, Eigen, or Half path isn't specified, set it to the thirdparty directory.
 if(NOT FLATBUFFERS_DIR)
     set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/")
 endif()
@@ -52,6 +52,10 @@ if(NOT EIGEN_DIR)
     set(EIGEN_DIR "../thirdparty/eigen/")
 endif()
 
+if(NOT HALF_DIR)
+    set(HALF_DIR "../thirdparty/serialization_lib/third_party/half")
+endif()
+
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 
 # Common sources required for TOSA Reference Model library, executable and unit tests
@@ -92,6 +96,7 @@ target_include_directories(tosa_reference_model_lib
         ${EIGEN_DIR}
         ${EIGEN_DIR}/unsupported/
         ${SERIALIZATION_DIR}/include
+        ${HALF_DIR}/include
 )
 
 target_link_libraries(tosa_reference_model_lib
@@ -130,6 +135,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
            ${EIGEN_DIR}
            ${EIGEN_DIR}/unsupported/
            ${SERIALIZATION_DIR}/include
+           ${HALF_DIR}/include
     )
 
     target_link_libraries(tosa_reference_model
@@ -171,6 +177,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
           ${EIGEN_DIR}
          ${EIGEN_DIR}/unsupported/
          ${SERIALIZATION_DIR}/include
+         ${HALF_DIR}/include
          ${DOCTEST_DIR}
    )
 
@@ -203,6 +210,7 @@ if(BUILD_MODEL_RUNNER_SAMPLE)
            ${EIGEN_DIR}
            ${EIGEN_DIR}/unsupported/
            ${SERIALIZATION_DIR}/include
+           ${HALF_DIR}/include
     )
 
     target_link_libraries(model_runner_sample
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index adf81b2..874f1d8 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
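A quick sketch of the half-precision type the new include path provides (assuming only that half.hpp is reachable via ${HALF_DIR}/include): half_float::half is an IEEE 754 binary16 value with conversions to and from float.

    #include "half.hpp"

    int main()
    {
        half_float::half h(1.5f);                  // stored in 16 bits
        float f = static_cast<float>(h) + 0.5f;    // widen back to float for arithmetic
        return (f == 2.0f) ? 0 : 1;                // exits 0: 1.5 is exactly representable
    }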
@@ -24,6 +24,9 @@
 #define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
 
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, RANK, DTYPE, ACCUM_DTYPE) \
+    template class TosaReference::OP<RANK, DType_##DTYPE, DType_##ACCUM_DTYPE>;
+
 #define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
     template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
 
@@ -35,8 +38,14 @@
 #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
 
+#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
+    template class TosaReference::OP<DType_##DTYPE, DType_##ACCUM_DTYPE>;
+
 #define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
 
+#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, DTYPE1, DTYPE2, ACCUM_DTYPE) \
+    template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>;
+
 #define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, OP_TYPE) \
     template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, OP_TYPE>;
 
@@ -57,6 +66,14 @@
     DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
     DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
 
+#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 1, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 2, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 3, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 4, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 5, DTYPE, ACCUM_DTYPE) \
+    DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 6, DTYPE, ACCUM_DTYPE)
+
 #define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
     DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
     DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index c344bcb..1c0c23a 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ int OpClamp<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
+        case DType_FP16:
         case DType_FLOAT:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
@@ -57,6 +58,7 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
+        case DType_FP16:
         case DType_FLOAT:
             this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
             break;
@@ -72,6 +74,7 @@ int OpTanh<Rank, Dtype>::register_fcn()
 {
     switch (Dtype)
     {
+        case DType_FP16:
         case DType_FLOAT:
             this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
             break;
@@ -83,10 +86,13 @@ int OpTanh<Rank, Dtype>::register_fcn()
 }
 
 // template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index ab89e24..5930c1a 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ int OpEqual::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; @@ -43,6 +44,7 @@ int OpGreater::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; @@ -59,6 +61,7 @@ int OpGreaterEqual::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; @@ -71,11 +74,14 @@ int OpGreaterEqual::register_fcn() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 7840450..1ed0be2 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
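To make the instantiation pattern concrete: a sketch of what one of the new _ONE_ACCUM lines produces, assuming the macro bodies reconstructed in graph_node.h above.

    // DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT) expands to:
    template class TosaReference::OpConv2d<DType_FP16, DType_FP16, DType_FLOAT>;
    // One explicit instantiation is emitted per (input, weight, accumulator)
    // dtype combination, so FP16 ops with binary16 and binary32 accumulators
    // coexist as distinct concrete classes.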
@@ -190,6 +190,7 @@ int OpPad::eval() case DType_INT32: pad_value = (InEigenType)attribute->pad_const_int(); break; + case DType_FP16: case DType_FLOAT: pad_value = (InEigenType)attribute->pad_const_fp(); break; @@ -637,42 +638,49 @@ int OpTranspose::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT); DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 30c9511..4ff08be 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
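Why the existing float kernels (expf, tanhf, the pad_const_fp cast above) serve DType_FP16 without change: the reference model stores FP16 tensor data at full float precision, and genuine binary16 arithmetic only appears through the accumulator type. A sketch, assuming the GetEigenType/GetAccEigenType specializations added to template_types.h later in this patch, and that those traits live in the TosaReference namespace:

    #include <type_traits>
    #include "half.hpp"
    #include "template_types.h"

    using namespace TosaReference;

    // FP16 tensors are float-backed ("full precision used")...
    static_assert(std::is_same<GetEigenType<DType_FP16>::type, float>::value, "");
    // ...while an FP16 accumulator really narrows to binary16.
    static_assert(std::is_same<GetAccEigenType<DType_FP16>::type, half_float::half>::value, "");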
@@ -89,6 +89,7 @@ int OpIdentity::eval() // template explicit instantiation // note OpConst is not templated +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index a5e1a20..917d56e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -142,6 +142,7 @@ int OpAdd::register_fcn() return static_cast(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; @@ -369,6 +370,7 @@ int OpMaximum::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; @@ -385,6 +387,7 @@ int OpMinimum::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; @@ -403,6 +406,7 @@ int OpMul::register_fcn() switch (InDtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; @@ -452,6 +456,7 @@ int OpPow::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; @@ -476,6 +481,7 @@ int OpSub::register_fcn() return static_cast(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; @@ -574,6 +580,7 @@ int OpTable::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); @@ -609,24 +616,32 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); +// Instantiation of nodes 
for comparison operators opEqual, opGreater +// and opGreaterEqual +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 72fe5a0..da046a7 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -107,6 +107,7 @@ int OpSelect<0, Dtype>::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8ef1e3c..52f5aff 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -79,6 +79,7 @@ int OpAbs::register_fcn() switch (Dtype) { case DType_FLOAT: + case DType_FP16: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; @@ -111,6 +112,7 @@ int OpCeil::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; @@ -158,6 +160,7 @@ int OpExp::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; @@ -173,6 +176,7 @@ int OpFloor::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; @@ -188,6 +192,7 @@ int OpLog::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; @@ -239,6 +244,7 @@ int OpNegate::register_fcn() switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); @@ -290,6 +296,7 @@ int OpReciprocal::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; @@ -305,6 +312,7 @@ int OpRsqrt::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; @@ -316,6 +324,7 @@ int OpRsqrt::register_fcn() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); @@ -323,23 +332,30 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, 
FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index a021226..891261b 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -63,7 +63,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -71,7 +71,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -224,4 +224,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); DEF_INSTANTIATE_THREE_TYPE(OpResize, FLOAT, FLOAT, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index b6a2e15..fd73eb5 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -48,71 +48,91 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, { // tensor_ops case Op_ARGMAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); break; case Op_AVG_POOL2D: - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: - DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); break; case Op_CONV3D: - DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: - DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, 
INT16, INT8, INT48); break; case Op_MATMUL: - DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT8); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); break; // activation_funcs case Op_CLAMP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); break; case Op_TANH: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); break; // ewise_binary case Op_ADD: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -159,23 +179,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); break; case Op_MAXIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); break; case Op_POW: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); break; case Op_SUB: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -186,6 +211,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -195,38 +221,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); break; case Op_CEIL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; case Op_EXP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); break; case Op_FLOOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); break; case Op_LOG: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); break; case Op_NEGATE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); break; case Op_RECIPROCAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); break; case Op_RSQRT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); break; // ewise_ternary case Op_SELECT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -236,14 +270,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -256,27 +293,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); break; case Op_REDUCE_MAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); break; case Op_REDUCE_MIN: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); break; case Op_REDUCE_PRODUCT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); break; case Op_REDUCE_SUM: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; // data layout case Op_CONCAT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -284,6 +326,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); break; case Op_PAD: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -291,6 +334,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); break; case Op_RESHAPE: + DEF_FACTORY_RESHAPE(OpReshape, FP16); DEF_FACTORY_RESHAPE(OpReshape, FLOAT); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -298,6 +342,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, BOOL); break; case Op_REVERSE: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -305,6 +350,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); break; case Op_SLICE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -312,6 +358,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); break; case Op_TILE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -320,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -331,12 +379,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT8); DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); + DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); break; @@ -346,6 +396,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT8, INT8); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); + DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT); break; @@ -353,6 +404,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONST: return new OpConst(sgt, id); case Op_IDENTITY: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -368,15 +420,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 341d7dc..25dfc6e 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -62,18 +62,55 @@
         return new OP<DType_##DTYPE>(sgt, attribute, id); \
     }
 
+#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
+    if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+    { \
+        return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+    }
+
 #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
     if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
     { \
         return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
     }
 
+#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
+    if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
+        && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+    { \
+        return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+    } \
+
+// Statement-expression to evaluate accumulate attribute in-place
+#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
+    ({ \
+        tosa::DType accumDType = tosa::DType_UNKNOWN; \
+        if (auto p = dynamic_cast<tosa::Tosa##ATTRIBUTE_NAME##Attribute*>(attribute)) \
+        { \
+            auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \
+            ASSERT_MEM(attr); \
+            accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \
+        } \
+        else \
+        { \
+            FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
+                        "of this attribute is required in order to determine the accumulate type."); \
+        } \
+        accumDType; \
+    }) \
+
 #define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
     if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
     { \
         return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
     }
 
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
+    if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+    { \
+        return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+    }
+
 #define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \
     if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
     { \
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 18fac44..03ee660 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
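How the two macro families meet in OpFactory::newOp: a sketch of the expansion of one factory line from op_factory.cc, assuming the macro bodies reconstructed above (ACCUM_FROM_ATTRIBUTE reads the accumulator dtype out of the serialized attribute before the op object exists):

    // DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT) expands to:
    if (inputDType == DType_FP16 && weightDType == DType_FP16
        && ACCUM_FROM_ATTRIBUTE(Conv) == DType_FLOAT)
    {
        return new OpConv2d<DType_FP16, DType_FP16, DType_FLOAT>(sgt, attribute, id);
    }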
@@ -158,17 +158,21 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
 
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
 
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index faf7db9..25174bd 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -226,9 +226,11 @@ int OpScatter<Dtype>::eval()
 DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
 
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 2bc7e04..9511c31 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 #include "tosa_generated.h"
 #include <Eigen/CXX11/Tensor>
+#include "half.hpp"
 
 using namespace tosa;
 
@@ -69,6 +70,12 @@ struct GetEigenType
     using type = float;
 };
 template <>
+struct GetEigenType<DType_FP16>
+{
+    // NOTE: full precision used
+    using type = float;
+};
+template <>
 struct GetEigenType<DType_INT32>
 {
     using type = int32_t;
 };
@@ -109,6 +116,28 @@ struct GetEigenType
     using type = int32_t;
 };
 
+/* Get Accumulate Eigen Type:
+Same behaviour as GetEigenType for all DTypes except the
+single specialised case of DType_FP16.
+*/
+template <DType Dtype>
+struct GetAccEigenType;
+template <>
+struct GetAccEigenType<DType_FP16>
+{
+    using type = half_float::half;
+};
+template <DType Dtype>
+struct GetAccEigenType
+{
+    using type = typename GetEigenType<Dtype>::type;
+};
+
+template <DType Dtype>
+struct GetHalfEigenType
+{
+    using type = half_float::half;
+};
+
 // Meta function to get number of bits
 template <DType T>
 struct GetNumBits
@@ -155,6 +184,11 @@ struct GetNumBits
 {
     static constexpr int32_t value = 48;
 };
+template <>
+struct GetNumBits<DType_FP16>
+{
+    static constexpr int32_t value = 16;
+};
 
 // Meta function to get quantized min/max in compile time
 template <DType T>
@@ -262,6 +296,11 @@ struct GetAccDType
     static constexpr DType value = DType_INT48;
 };
 template <>
+struct GetAccDType<DType_FP16, DType_FP16>
+{
+    static constexpr DType value = DType_FP16;
+};
+template <>
 struct GetAccDType<DType_FLOAT, DType_FLOAT>
 {
     static constexpr DType value = DType_FLOAT;
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 2cd94bb..c617dda 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
 #include "tensor_ops.h"
 #include "quant_util.h"
 #include "template_types.h"
+#include "half.hpp"
 
 using namespace TosaReference;
 using namespace Eigen;
@@ -329,8 +330,8 @@ int OpArgMax<Rank, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
-OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
     : GraphNode(sgt_, Op_AVG_POOL2D, id_)
@@ -341,15 +342,15 @@ OpAvgPool2d<Dtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Pool);
 }
 
-template <DType Dtype>
-OpAvgPool2d<Dtype>::~OpAvgPool2d()
+template <DType Dtype, DType AccDtype>
+OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -385,8 +386,8 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
 
 // This calculates the number of padding elements used for each location along an axis
 // Average pooling only divides by the number of elements used, not including padding.
 // This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <DType Dtype, DType AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
 {
     ETensor1<int32_t> result(out_size);
 
@@ -414,8 +415,8 @@ ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_
 
 // assuming input and output tensor have same scales like tflite reference
 // so no need to scale input and output
-template <DType Dtype>
-int OpAvgPool2d<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpAvgPool2d<Dtype, AccDtype>::eval()
 {
     int in_batch  = this->in->getShape()[0];
     int in_height = this->in->getShape()[1];
@@ -439,11 +440,13 @@ int OpAvgPool2d<Dtype>::eval()
     int stride_h = this->attribute->stride()[0];
     int stride_w = this->attribute->stride()[1];
 
+    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
     DEBUG_INFO(OP,
                "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
-               "stride=[%d,%d], pad=[%d,%d,%d,%d]",
+               "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
                in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
-               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right);
+               kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
 
     Eigen::array<Eigen::Index, 2> im2col_input_dims;
     im2col_input_dims[0] = kernel_h * kernel_w;
@@ -509,8 +512,7 @@ int OpAvgPool2d<Dtype>::eval()
             .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
             .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
             .broadcast(bcast);
-
-    if (Dtype != DType_FLOAT)
+    if (Dtype != DType_FLOAT && Dtype != DType_FP16)
     {
         try
         {
@@ -531,14 +533,15 @@ int OpAvgPool2d<Dtype>::eval()
     }
     else
     {
+        // Case for float-type pooling
         this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
     }
 
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                          TosaAttributeBase* attribute_,
                                          uint64_t id_)
     : GraphNode(sgt_, Op_CONV2D, id_)
@@ -549,15 +552,15 @@ OpConv2d<InDtype, WeightDtype>::OpConv2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype>
-OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -574,12 +577,12 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
     }
 
     ERROR_IF(outputs[0]->getDtype() != AccDtype,
-            "OpConv2d: Output data type not supported for this configuration of operator");
+             "OpConv2d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
-    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
     std::string msg;
     if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(),
@@ -593,8 +596,8 @@ int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_height = this->input->getShape()[1];
@@ -630,12 +633,14 @@ int OpConv2d<InDtype, WeightDtype>::eval()
     int dilation_h = this->attribute->dilation()[0];
     int dilation_w = this->attribute->dilation()[1];
 
+    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
DEBUG_INFO(OP, "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " - "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", + "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right); + pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); // GEMM-conv2d, left matrix is input, right matrix is weight Eigen::array im2col_input_dims; @@ -695,33 +700,33 @@ int OpConv2d::eval() // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC] ETensor2 im2col_weight = - weight_val.shuffle(Eigen::array({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); + weight_val.shuffle(Eigen::array({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C] - ETensor2 bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims); + ETensor2 bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast(); // output matrix is [N * H * W, C] - ETensor2 contracted_result = - im2col_input.template cast().contract(im2col_weight.template cast(), contract_dims); + ETensor2 contracted_result = + (im2col_input.template cast().contract(im2col_weight.template cast(), contract_dims)).template cast(); // adding bias - ETensor2 biased_output = contracted_result + bias_2d.template cast(); + ETensor2 biased_output = contracted_result + bias_2d; // reshape back to [N, H, W, C] this->output->getTensor() = biased_output.reshape(col2im_output_dims); if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpConv3d::OpConv3d(SubgraphTraverser* sgt_, +template +OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) @@ -732,15 +737,15 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template -OpConv3d::~OpConv3d() +template +OpConv3d::~OpConv3d() { if (attribute) delete attribute; } -template -int OpConv3d::checkTensorAttributes() +template +int OpConv3d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -757,12 +762,12 @@ int OpConv3d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpConv3d: Output data type not supported for this configuration of operator"); + "OpConv3d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); std::string msg; if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(), @@ -776,8 +781,8 @@ int OpConv3d::checkTensorAttributes() return 0; } -template -int OpConv3d::eval() +template +int OpConv3d::eval() { int in_batch = 
this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -821,13 +826,15 @@ int OpConv3d::eval() int dilation_h = this->attribute->dilation()[1]; int dilation_w = this->attribute->dilation()[2]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + DEBUG_INFO( OP, "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], " - "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]", + "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s", in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels, out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h, - dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right); + dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); Eigen::array, 5> pad; pad[0] = std::make_pair(0, 0); @@ -860,7 +867,7 @@ int OpConv3d::eval() this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast); // 2. direct convolution - AccEigenType acc = 0; + AccEigenType acc(0.0); int d_idx, h_idx, w_idx; for (int ob = 0; ob < out_batch; ob++) @@ -874,7 +881,7 @@ int OpConv3d::eval() for (int oc = 0; oc < out_channels; oc++) { // Initialize accumulator with bias value - acc = this->output->getTensor()(ob, od, oh, ow, oc); + acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc); for (int fd = 0; fd < f_depth; fd++) { d_idx = od * stride_d + fd * dilation_d; @@ -892,7 +899,7 @@ int OpConv3d::eval() } } } - this->output->getTensor()(ob, od, oh, ow, oc) = acc; + this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc; } } } @@ -901,15 +908,15 @@ int OpConv3d::eval() if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, +template +OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) @@ -920,15 +927,15 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sg INIT_ATTRIBUTE(Conv); } -template -OpDepthwiseConv2d::~OpDepthwiseConv2d() +template +OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template -int OpDepthwiseConv2d::checkTensorAttributes() +template +int OpDepthwiseConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -945,12 +952,12 @@ int OpDepthwiseConv2d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); + "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); std::string msg; if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), @@ -964,8 +971,8 @@ int 
OpDepthwiseConv2d::checkTensorAttributes() return 0; } -template -int OpDepthwiseConv2d::eval() +template +int OpDepthwiseConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1002,12 +1009,14 @@ int OpDepthwiseConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + DEBUG_INFO(OP, "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right); + pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); Eigen::array, 4> pad; pad[0] = std::make_pair(0, 0); @@ -1061,9 +1070,10 @@ int OpDepthwiseConv2d::eval() { for (int fw = 0; fw < f_width; fw++) { + // Perform multiplication in AccEigenType then cast to OutEigenType this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) += - ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * - (AccEigenType)weight_val(fh, fw, ic, cm)); + (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * + (AccEigenType)weight_val(fh, fw, ic, cm)); } } } @@ -1074,15 +1084,15 @@ int OpDepthwiseConv2d::eval() if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, +template +OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) @@ -1093,15 +1103,15 @@ OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_ INIT_ATTRIBUTE(FullyConnected); } -template -OpFullyConnected::~OpFullyConnected() +template +OpFullyConnected::~OpFullyConnected() { if (attribute) delete attribute; } -template -int OpFullyConnected::checkTensorAttributes() +template +int OpFullyConnected::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1128,9 +1138,9 @@ int OpFullyConnected::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpFullyConnected: Output data type not supported for this configuration of operator"); + "OpFullyConnected: Output data type not supported for this configuration of operator"); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); @@ -1138,8 +1148,8 @@ int OpFullyConnected::checkTensorAttributes() return 0; } -template -int OpFullyConnected::eval() +template +int OpFullyConnected::eval() { typedef Eigen::Tensor::DimensionPair DimPair; 
 
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
                                                          TosaAttributeBase* attribute_,
                                                          uint64_t id_)
     : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
@@ -1093,15 +1103,15 @@ OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(FullyConnected);
 }
 
-template <DType InDtype, DType WeightDtype>
-OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1128,9 +1138,9 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
     }
 
     ERROR_IF(outputs[0]->getDtype() != AccDtype,
-             "OpFullyConnected: Output data type not supported for this configuration of operator");
+             "OpFullyConnected: Output data type not supported for this configuration of operator");
 
-    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
     ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
     ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
@@ -1138,8 +1148,8 @@ int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpFullyConnected<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
     Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1163,19 +1173,19 @@ int OpFullyConnected<InDtype, WeightDtype>::eval()
     }
 
     this->output->getTensor() =
-        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
-        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+        input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
+        this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
 
     if (AccDtype == DType_INT48)
     {
-        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
-        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
     }
 
     return GraphNode::eval();
 }
 
-template <DType Dtype>
-OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
     : GraphNode(sgt_, Op_MATMUL, id_)
@@ -1186,15 +1196,15 @@ OpMatMul<Dtype>::OpMatMul(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(MatMul);
 }
 
-template <DType Dtype>
-OpMatMul<Dtype>::~OpMatMul()
+template <DType Dtype, DType AccDtype>
+OpMatMul<Dtype, AccDtype>::~OpMatMul()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype>
-int OpMatMul<Dtype>::checkTensorAttributes()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1205,11 +1215,11 @@ int OpMatMul<Dtype>::checkTensorAttributes()
     }
 
     ERROR_IF(outputs[0]->getDtype() != AccDtype,
-             "OpMatMul: Output data type not supported for this configuration of operator");
+             "OpMatMul: Output data type not supported for this configuration of operator");
 
     a      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     b      = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
-    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
     ASSERT_MEM(a && b && output);
 
@@ -1255,8 +1265,8 @@ int OpMatMul<Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType Dtype>
-int OpMatMul<Dtype>::eval()
+template <DType Dtype, DType AccDtype>
+int OpMatMul<Dtype, AccDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
     Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1289,22 +1299,22 @@ int OpMatMul<Dtype>::eval()
         TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape);
         TAccRank2 output_rank2_val =
             a_rank2_val.template cast<AccEigenType>().contract(b_rank2_val.template cast<AccEigenType>(), dims);
-        TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape);
+        TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast<OutEigenType>();
         if (i == 0)
         {
             this->output->getTensor() = output_rank3_val;
         }
         else
         {
-            TAcc temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
+            TOut temp                 = this->output->getTensor().concatenate(output_rank3_val, 0);
             this->output->getTensor() = temp;
         }
     }
 
     if (AccDtype == DType_INT48)
     {
-        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
-        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
     }
 
     return GraphNode::eval();
@@ -1442,8 +1452,8 @@ int OpMaxPool2d<Dtype>::eval()
     return GraphNode::eval();
 }
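// --- Illustrative sketch, not from the patch ----------------------------------
// FULLY_CONNECTED and MATMUL widen both operands to AccEigenType, contract,
// then narrow to OutEigenType. A standalone Eigen illustration of the same
// cast -> contract -> cast pattern; the rank-2 shapes and the use of
// Eigen::half here are assumptions, not the reference model's actual types:

#include <unsupported/Eigen/CXX11/Tensor>

Eigen::Tensor<float, 2> matmul_fp16_fp32acc(const Eigen::Tensor<Eigen::half, 2>& a,
                                            const Eigen::Tensor<Eigen::half, 2>& b)
{
    using DimPair = Eigen::Tensor<int, 1>::DimensionPair;
    Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
    // widen to the fp32 accumulator, contract, store at accumulator precision
    return a.cast<float>().contract(b.cast<float>(), dims);
}
// ------------------------------------------------------------------------------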
 
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
                                                            TosaAttributeBase* attribute_,
                                                            uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
@@ -1454,15 +1464,15 @@ OpTransposeConv2d<InDtype, WeightDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <DType InDtype, DType WeightDtype>
-OpTransposeConv2d<InDtype, WeightDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1473,12 +1483,12 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
     }
 
     ERROR_IF(outputs[0]->getDtype() != AccDtype,
-             "OpTransposeConv2d: Output data type not supported for this configuration of operator");
+             "OpTransposeConv2d: Output data type not supported for this configuration of operator");
 
     input  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
     bias   = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
-    output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+    output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
     if (attribute->out_pad().size() != 4)
     {
@@ -1556,8 +1566,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype>
-int OpTransposeConv2d<InDtype, WeightDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType AccDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_height = this->input->getShape()[1];
@@ -1584,6 +1594,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
     int stride_h = this->attribute->stride()[0];
     int stride_w = this->attribute->stride()[1];
 
+    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+
     ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
     ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
              in_channels);
@@ -1594,10 +1606,10 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
 
     DEBUG_INFO(OP,
                "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
-               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
+               "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
               in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
               out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
-               out_pad_bottom, out_pad_left, out_pad_right);
+               out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);
 
     TIn input_val      = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
@@ -1645,8 +1657,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
                         if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
                         {
                             this->output->getTensor()(ob, out_y, out_x, oc) +=
-                                ((AccEigenType)input_val(ob, ih, iw, ic) *
-                                 (AccEigenType)weight_val(oc, fh, fw, ic));
+                                (OutEigenType)((AccEigenType)input_val(ob, ih, iw, ic) *
+                                               (AccEigenType)weight_val(oc, fh, fw, ic));
                         }
                     }
                 }
@@ -1658,51 +1670,68 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
     }
 
     if (AccDtype == DType_INT48)
     {
-        this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
-        this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+        this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
+        this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
     }
 
     return GraphNode::eval();
 }
 
 // template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
 
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT8)
-DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
-
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
-
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT8);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
-DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
-
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
+
+                                   // [in_t, weight_t, acc_t]
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
 
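// --- Illustrative note, not from the patch ------------------------------------
// Each instantiation triple above is a (in_t, weight_t, acc_t) combination:
// FP16 inputs may accumulate in FP16 or FP32, FP32 accumulates in FP32, INT8
// in INT32 and INT16 in INT48. Selection at graph-build time presumably keys
// on the serialized accum_dtype attribute; a hypothetical dispatch sketch
// (the real one lives in op_factory.cc, which is not shown in this excerpt):

GraphNode* make_conv2d(SubgraphTraverser* sgt, TosaAttributeBase* attr, uint64_t id,
                       DType in, DType weight, DType accum)
{
    if (in == DType_FP16 && weight == DType_FP16 && accum == DType_FP16)
        return new OpConv2d<DType_FP16, DType_FP16, DType_FP16>(sgt, attr, id);
    if (in == DType_FP16 && weight == DType_FP16 && accum == DType_FLOAT)
        return new OpConv2d<DType_FP16, DType_FP16, DType_FLOAT>(sgt, attr, id);
    if (in == DType_FLOAT && weight == DType_FLOAT && accum == DType_FLOAT)
        return new OpConv2d<DType_FLOAT, DType_FLOAT, DType_FLOAT>(sgt, attr, id);
    return nullptr;    // integer combinations elided for brevity
}
// ------------------------------------------------------------------------------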
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 24eadeb..fd6dd25 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@ protected:
     TosaReference::TensorTemplate<TOut>* output;
 };
 
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
 class OpAvgPool2d : public GraphNode
 {
 public:
@@ -55,9 +55,8 @@ public:
     virtual int checkTensorAttributes();
     virtual int eval();
 
-    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
     using InEigenType  = typename GetEigenType<Dtype>::type;
-    using AccEigenType = typename GetEigenType<AccDtype>::type;
+    using AccEigenType = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType = typename GetEigenType<Dtype>::type;
     using TIn          = Eigen::Tensor<InEigenType, 4>;
     using TOut         = Eigen::Tensor<OutEigenType, 4>;
@@ -75,7 +74,7 @@ protected:
     ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpConv2d : public GraphNode
 {
 public:
@@ -85,15 +84,14 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -102,11 +100,11 @@ protected:
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpConv3d : public GraphNode
 {
 public:
@@ -116,15 +114,14 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 5>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 5>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 5>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -133,11 +130,11 @@ protected:
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
    TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
 public:
@@ -147,15 +144,14 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -164,11 +160,11 @@ protected:
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     tosa::TosaConvAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpFullyConnected : public GraphNode
 {
 public:
@@ -178,14 +174,14 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 2>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 2>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 2>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 2>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -194,12 +190,12 @@ protected:
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
 
     tosa::TosaFullyConnectedAttribute* attribute;
 };
 
-template <DType Dtype>
+template <DType Dtype, DType AccDtype>
 class OpMatMul : public GraphNode
 {
 public:
@@ -209,11 +205,11 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
     using InEigenType  = typename GetEigenType<Dtype>::type;
-    using AccEigenType = typename GetEigenType<AccDtype>::type;
+    using AccEigenType = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType = typename GetEigenType<AccDtype>::type;
     using TIn          = Eigen::Tensor<InEigenType, 3>;
-    using TAcc         = Eigen::Tensor<AccEigenType, 3>;
+    using TOut         = Eigen::Tensor<OutEigenType, 3>;
     using TInRank2     = Eigen::Tensor<InEigenType, 2>;
     using TAccRank2    = Eigen::Tensor<AccEigenType, 2>;
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
@@ -222,7 +218,7 @@ public:
 protected:
     TosaReference::TensorTemplate<TIn>* a;
     TosaReference::TensorTemplate<TIn>* b;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     int64_t N;
     int64_t H;
     int64_t W;
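// --- Illustrative sketch, not from the patch ----------------------------------
// GetAccEigenType, referenced by the "Note: different from GetEigenType"
// comments above, is defined in template_types.h (outside this excerpt). Its
// presumable shape: every dtype accumulates in its storage type except FP16,
// whose tensors are stored as float but whose accumulator must round at half
// precision after every update:

#include "half.hpp"

template <DType Dtype>
struct GetAccEigenType
{
    using type = typename GetEigenType<Dtype>::type;    // default: the storage type
};

template <>
struct GetAccEigenType<DType_FP16>
{
    using type = half_float::half;    // accumulate at true fp16 precision
};
// ------------------------------------------------------------------------------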
@@ -252,7 +248,7 @@ protected:
     tosa::TosaPoolAttribute* attribute;
 };
 
-template <DType InDtype, DType WeightDtype>
+template <DType InDtype, DType WeightDtype, DType AccDtype>
 class OpTransposeConv2d : public GraphNode
 {
 public:
@@ -262,15 +258,14 @@ public:
     virtual int checkTensorAttributes() final;
     virtual int eval() final;
 
-    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
-
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetEigenType<AccDtype>::type;
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
+    using OutEigenType    = typename GetEigenType<AccDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
-    using TBias           = Eigen::Tensor<AccEigenType, 1>;
-    using TAcc            = Eigen::Tensor<AccEigenType, 4>;
+    using TBias           = Eigen::Tensor<OutEigenType, 1>;
+    using TOut            = Eigen::Tensor<OutEigenType, 4>;
 
     static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
     static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
@@ -279,7 +274,7 @@ protected:
     TosaReference::TensorTemplate<TIn>* input;
     TosaReference::TensorTemplate<TWeight>* weight;
     TosaReference::TensorTemplate<TBias>* bias;
-    TosaReference::TensorTemplate<TAcc>* output;
+    TosaReference::TensorTemplate<TOut>* output;
     TosaTransposeConvAttribute* attribute;
 };
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 52de2e4..50e710a 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
 #include "quant_util.h"
 #include "template_types.h"
 #include <cmath>
+#include "half.hpp"
 
 using namespace TosaReference;
 using namespace Eigen;
@@ -286,6 +287,30 @@ CastHelper<InDtype, DType_BOOL>::CastHelper()
     };
 }
 
+template <DType InDtype>
+CastHelper<InDtype, DType_FP16>::CastHelper()
+{
+    fcn = [](InEigenType in) -> float {
+        half_float::half out = half_float::half_cast<half_float::half, InEigenType>(in);    // Cast to half_float
+        return half_float::half_cast<float, half_float::half>(out);    // Cast to float (underlying FP16 EigenType)
+    };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FP16, OutDtype>::CastHelper()
+{
+    // Assuming InEigenType = float.
+    fcn = [](float in) -> OutEigenType {
+        // Perform initial rounding in half-precision then cast back to float
+        half_float::half h = half_float::half_cast<half_float::half, float>(in);
+        h                  = std::round(h);
+        OutEigenType out   = half_float::half_cast<OutEigenType, half_float::half>(h);
+        out                = std::max(out, OutMin);
+        out                = std::min(out, OutMax);
+        return out;
+    };
+}
+
 template <DType InDtype>
 CastHelper<InDtype, DType_FLOAT>::CastHelper()
 {
@@ -313,15 +338,21 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
 DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 53470d1..5f197cf 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -100,6 +100,42 @@ private:
     FcnType fcn;
 };
 
+template <DType InDtype>
+class CastHelper<InDtype, DType_FP16>
+{
+public:
+    using InEigenType  = typename GetEigenType<InDtype>::type;
+    using OutEigenType = typename GetEigenType<DType_FP16>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FP16, OutDtype>
+{
+public:
+    using InEigenType  = typename GetEigenType<DType_FP16>::type;
+    using OutEigenType = typename GetEigenType<OutDtype>::type;
+    using FcnType      = std::function<OutEigenType(InEigenType)>;
+    static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+    static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+    CastHelper();
+    const FcnType& get_fcn() const
+    {
+        return fcn;
+    }
+
+private:
+    FcnType fcn;
+};
+
 template <DType InDtype>
 class CastHelper<InDtype, DType_FLOAT>
 {
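// --- Illustrative sketch, not from the patch ----------------------------------
// The FP16 -> integer CastHelper added above rounds at half precision before
// the clamp, so e.g. 126.7f first snaps to 126.6875f (the nearest fp16 value)
// and only then rounds to 127. A standalone int8 version of the same logic:

#include <algorithm>
#include <cmath>
#include "half.hpp"

int8_t cast_fp16_to_int8(float in)
{
    half_float::half h = half_float::half_cast<half_float::half, float>(in);
    h = std::round(h);    // promoted to float for rounding, converted back to half
    float out = half_float::half_cast<float, half_float::half>(h);
    out = std::max(out, -128.0f);    // GetQMin<DType_INT8>::value
    out = std::min(out, 127.0f);     // GetQMax<DType_INT8>::value
    return (int8_t)out;
}
// ------------------------------------------------------------------------------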
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index d0cc6cf..386f0e5 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -385,6 +385,14 @@ int SubgraphTraverser::allocateTensor()
                 tensor->setTensorValueInt64(i64_data.size(), i64_data.data());
             }
             break;
+            case DType_FP16:
+            {
+                // Interpret f16 data as float
+                std::vector<float> f16_data;
+                TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data);
+                tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
+            }
+            break;
             case DType_FLOAT:
             {
                 std::vector<float> fp32_data;
@@ -693,7 +701,7 @@ int SubgraphTraverser::validateGraph()
             DType dtype = currTensor->getDtype();
 
             // Float-point disallowed
-            if (dtype == DType_FLOAT)
+            if (dtype == DType_FLOAT || dtype == DType_FP16)
             {
                 WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
                         "disabled, but %s tensor %s found\n",
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 7cbeb13..cbe12a9 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
 
 #include "tensor.h"
 #include "arith_util.h"
+#include "half.hpp"
 
 using namespace TosaReference;
 using namespace Eigen;
@@ -84,6 +85,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
 {
     uint32_t elements = getElementCount();
     float* fdatabuf   = nullptr;
+    half_float::half* f16databuf = nullptr;
     int32_t* i32databuf = nullptr;
     int64_t* i64databuf = nullptr;
     bool* bdatabuf      = nullptr;
@@ -97,6 +99,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
 
             nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
             break;
+        case DType_FP16:
+            f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+            ASSERT_MEM(f16databuf);
+            fdatabuf = (float*)calloc(sizeof(float), elements);
+            ASSERT_MEM(fdatabuf);
+
+            nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf);
+            break;
         case DType_INT32:
         case DType_UINT8:
         case DType_INT4:
@@ -146,9 +156,17 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
 
     switch (getDtype())
     {
+        case DType_FP16:
+            // Convert from fp16 to fp32
+            for (uint32_t i = 0; i < elements; i++)
+            {
+                fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
+            }
+            // Fall through to DType_FLOAT case
         case DType_FLOAT:
             if (setTensorValueFloat(elements, fdatabuf))
             {
+                if (f16databuf)
+                    free(f16databuf);
                 free(fdatabuf);
                 return 1;
             }
@@ -187,6 +205,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
 
     if (fdatabuf)
         free(fdatabuf);
+    if (f16databuf)
+        free(f16databuf);
     if (i32databuf)
         free(i32databuf);
     if (i64databuf)
@@ -200,11 +220,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
 int TosaReference::Tensor::writeToNpyFile(const char* filename) const
 {
     float* fdatabuf   = nullptr;
+    half_float::half* f16databuf = nullptr;
     int32_t* i32databuf = nullptr;
     int64_t* i64databuf = nullptr;
     bool* bdatabuf      = nullptr;
     NumpyUtilities::NPError nperror;
-    int elements = getElementCount();
+    uint32_t elements = getElementCount();
 
     switch (getDtype())
     {
@@ -222,6 +243,27 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
 
             free(fdatabuf);
             break;
+        case DType_FP16:
+            fdatabuf = (float*)calloc(sizeof(float), elements);
+            ASSERT_MEM(fdatabuf);
+            f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
+            ASSERT_MEM(f16databuf);
+
+            if (getTensorValueFloat(elements, fdatabuf))
+            {
+                free(fdatabuf);
+                free(f16databuf);
+                return 1;
+            }
+            // Convert fp32 to fp16
+            for (uint32_t i = 0; i < elements; i++)
+            {
+                f16databuf[i] = half_float::half_cast<half_float::half, float>(fdatabuf[i]);
+            }
+            nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf);
+
+            free(fdatabuf);
+            free(f16databuf);
+            break;
         case DType_INT32:
         case DType_UINT8:
         case DType_INT4:
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 6b7d5f1..78a210e 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -1,5 +1,5 @@
 
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -643,6 +643,7 @@ public:
         switch (tensorDtype_)
         {
             case DType_FLOAT:
+            case DType_FP16:
                 switch (rank)
                 {
                     case 0:
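Illustrative sketch, not from the patch: the tensor I/O above stages every FP16
tensor through an fp32 working buffer, narrowing with half_float only at the
.npy file boundary. The round-trip, with hypothetical helper names:

#include <vector>
#include "half.hpp"

// fp16 file data -> fp32 working buffer (as readFromNpyFile does)
void widen_f16(const std::vector<half_float::half>& f16, std::vector<float>& f32)
{
    f32.resize(f16.size());
    for (size_t i = 0; i < f16.size(); i++)
        f32[i] = half_float::half_cast<float, half_float::half>(f16[i]);
}

// fp32 working buffer -> fp16 file data (as writeToNpyFile does)
void narrow_f16(const std::vector<float>& f32, std::vector<half_float::half>& f16)
{
    f16.resize(f32.size());
    for (size_t i = 0; i < f32.size(); i++)
        f16[i] = half_float::half_cast<half_float::half, float>(f32[i]);
}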