From 8b39043c70332e1e4c95ee6a9616aec40dd3baf1 Mon Sep 17 00:00:00 2001 From: James Ward Date: Fri, 12 Aug 2022 20:48:56 +0100 Subject: Reference model changes for fp16 support Change-Id: I72f21fcfa153046274969d327313e3349981dbe6 Signed-off-by: James Ward --- CMakeLists.txt | 1 + reference_model/CMakeLists.txt | 10 +- reference_model/src/graph_node.h | 19 +- reference_model/src/ops/activation_funcs.cc | 8 +- reference_model/src/ops/comparison.cc | 8 +- reference_model/src/ops/data_layout.cc | 10 +- reference_model/src/ops/data_nodes.cc | 3 +- reference_model/src/ops/ewise_binary.cc | 17 +- reference_model/src/ops/ewise_ternary.cc | 3 +- reference_model/src/ops/ewise_unary.cc | 18 +- reference_model/src/ops/image.cc | 7 +- reference_model/src/ops/op_factory.cc | 112 +++++++--- reference_model/src/ops/op_factory.h | 39 +++- reference_model/src/ops/reduction.cc | 6 +- reference_model/src/ops/scatter_gather.cc | 4 +- reference_model/src/ops/template_types.h | 41 +++- reference_model/src/ops/tensor_ops.cc | 315 +++++++++++++++------------- reference_model/src/ops/tensor_ops.h | 81 ++++--- reference_model/src/ops/type_conversion.cc | 33 ++- reference_model/src/ops/type_conversion.h | 38 +++- reference_model/src/subgraph_traverser.cc | 12 +- reference_model/src/tensor.cc | 46 +++- reference_model/src/tensor.h | 3 +- scripts/json2numpy/json2numpy.py | 4 +- thirdparty/serialization_lib | 2 +- verif/checker/tosa_result_checker.py | 4 +- verif/generator/tosa_arg_gen.py | 230 ++++++++++++++------ verif/generator/tosa_error_if.py | 74 ++++++- verif/generator/tosa_test_gen.py | 298 +++++++++++++++----------- verif/generator/tosa_utils.py | 39 ++++ verif/tests/test_tosa_result_checker.py | 2 +- 31 files changed, 1046 insertions(+), 441 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d3281c6..6556b27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(BUILD_MODEL_RUNNER_SAMPLE "Enable building of ModelRunner sample executab option(SERIALIZATION_DIR "Location where the TOSA Serialization Library 'include' folder is found" Off) option(FLATBUFFERS_DIR "Location where the FlatBuffers 'include' and 'lib' folders is found" Off) option(EIGEN_DIR "Location where the Eigen folder is found" Off) +option(HALF_DIR "Location where the Half folder is found" Off) option(DOCTEST_DIR "Location where the doctest folder is found (If building unit tests)" Off) add_subdirectory(thirdparty) diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index a790968..04b0db5 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -43,7 +43,7 @@ else() set(SERIALIZATION_DIR "../thirdparty/serialization_lib/") endif() -# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory. +# If Flatbuffers, Eigen, Half path isn't specified, set to thirdparty directory. 
if(NOT FLATBUFFERS_DIR) set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/") endif() @@ -52,6 +52,10 @@ if(NOT EIGEN_DIR) set(EIGEN_DIR "../thirdparty/eigen/") endif() +if(NOT HALF_DIR) + set(HALF_DIR "../thirdparty/serialization_lib/third_party/half") +endif() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) # Common sources required for TOSA Reference Model library, executable and unit tests @@ -92,6 +96,7 @@ target_include_directories(tosa_reference_model_lib ${EIGEN_DIR} ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include + ${HALF_DIR}/include ) target_link_libraries(tosa_reference_model_lib @@ -130,6 +135,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) ${EIGEN_DIR} ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include + ${HALF_DIR}/include ) target_link_libraries(tosa_reference_model @@ -171,6 +177,7 @@ if(BUILD_TOSA_REFERENCE_MODEL_TESTS) ${EIGEN_DIR} ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include + ${HALF_DIR}/include ${DOCTEST_DIR} ) @@ -203,6 +210,7 @@ if(BUILD_MODEL_RUNNER_SAMPLE) ${EIGEN_DIR} ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include + ${HALF_DIR}/include ) target_link_libraries(model_runner_sample diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index adf81b2..874f1d8 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,9 @@ #define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP; +#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, RANK, DTYPE, ACCUM_DTYPE) \ + template class TosaReference::OP; + #define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ template class TosaReference::OP; @@ -35,8 +38,14 @@ #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP; +#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \ + template class TosaReference::OP; + #define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP; +#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, DTYPE1, DTYPE2, ACCUM_DTYPE) \ + template class TosaReference::OP; + #define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, OP_TYPE) \ template class TosaReference::OP; @@ -57,6 +66,14 @@ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) +#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 1, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 2, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 3, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 4, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 5, DTYPE, ACCUM_DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 6, DTYPE, ACCUM_DTYPE) + #define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index c344bcb..1c0c23a 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -1,5 +1,5 @@ -// 
Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ int OpClamp::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: { InEigenType min = (InEigenType)attribute->min_fp(); @@ -57,6 +58,7 @@ int OpSigmoid::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; break; @@ -72,6 +74,7 @@ int OpTanh::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; break; @@ -83,10 +86,13 @@ int OpTanh::register_fcn() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index ab89e24..5930c1a 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ int OpEqual::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; @@ -43,6 +44,7 @@ int OpGreater::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; @@ -59,6 +61,7 @@ int OpGreaterEqual::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; @@ -71,11 +74,14 @@ int OpGreaterEqual::register_fcn() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 7840450..1ed0be2 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -190,6 +190,7 @@ int OpPad::eval() case DType_INT32: pad_value = (InEigenType)attribute->pad_const_int(); break; + case DType_FP16: case DType_FLOAT: pad_value = (InEigenType)attribute->pad_const_fp(); break; @@ -637,42 +638,49 @@ int OpTranspose::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT); DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 30c9511..4ff08be 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -89,6 +89,7 @@ int OpIdentity::eval() // template explicit instantiation // note OpConst is not templated +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index a5e1a20..917d56e 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -142,6 +142,7 @@ int OpAdd::register_fcn() return static_cast(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; break; @@ -369,6 +370,7 @@ int OpMaximum::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; @@ -385,6 +387,7 @@ int OpMinimum::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; @@ -403,6 +406,7 @@ int OpMul::register_fcn() switch (InDtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; break; @@ -452,6 +456,7 @@ int OpPow::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; break; @@ -476,6 +481,7 @@ int OpSub::register_fcn() return static_cast(res_in_64); }; break; + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; break; @@ -574,6 +580,7 @@ int OpTable::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); @@ -609,24 +616,32 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); +// Instantiation of nodes 
for comparison operators opEqual, opGreater +// and opGreaterEqual +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 72fe5a0..da046a7 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -107,6 +107,7 @@ int OpSelect<0, Dtype>::eval() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8ef1e3c..52f5aff 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -79,6 +79,7 @@ int OpAbs::register_fcn() switch (Dtype) { case DType_FLOAT: + case DType_FP16: case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; @@ -111,6 +112,7 @@ int OpCeil::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; break; @@ -158,6 +160,7 @@ int OpExp::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; break; @@ -173,6 +176,7 @@ int OpFloor::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; break; @@ -188,6 +192,7 @@ int OpLog::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; break; @@ -239,6 +244,7 @@ int OpNegate::register_fcn() switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); @@ -290,6 +296,7 @@ int OpReciprocal::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; break; @@ -305,6 +312,7 @@ int OpRsqrt::register_fcn() { switch (Dtype) { + case DType_FP16: case DType_FLOAT: this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; break; @@ -316,6 +324,7 @@ int OpRsqrt::register_fcn() } // template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); @@ -323,23 +332,30 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, 
FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index a021226..891261b 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -63,7 +63,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -71,7 +71,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT && OutDtype != DType_FP16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -224,4 +224,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); +DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); DEF_INSTANTIATE_THREE_TYPE(OpResize, FLOAT, FLOAT, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index b6a2e15..fd73eb5 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -48,71 +48,91 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, { // tensor_ops case Op_ARGMAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); break; case Op_AVG_POOL2D: - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: - DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); break; case Op_CONV3D: - DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: - DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, 
INT16, INT8, INT48); break; case Op_MATMUL: - DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT8); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); break; // activation_funcs case Op_CLAMP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); break; case Op_TANH: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); break; // ewise_binary case Op_ADD: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -159,23 +179,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); break; case Op_MAXIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); break; case Op_POW: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); break; case Op_SUB: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -186,6 +211,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -195,38 +221,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); break; case Op_CEIL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; case Op_EXP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); break; case Op_FLOOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); break; case Op_LOG: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); break; case Op_NEGATE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); break; case Op_RECIPROCAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); break; case Op_RSQRT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); break; // ewise_ternary case Op_SELECT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -236,14 +270,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -256,27 +293,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); break; case Op_REDUCE_MAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); break; case Op_REDUCE_MIN: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); break; case Op_REDUCE_PRODUCT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); break; case Op_REDUCE_SUM: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; // data layout case Op_CONCAT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -284,6 +326,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); break; case Op_PAD: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -291,6 +334,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); break; case Op_RESHAPE: + DEF_FACTORY_RESHAPE(OpReshape, FP16); DEF_FACTORY_RESHAPE(OpReshape, FLOAT); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -298,6 +342,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, BOOL); break; case Op_REVERSE: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -305,6 +350,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); break; case Op_SLICE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -312,6 +358,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); break; case Op_TILE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -320,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -331,12 +379,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT8); DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); + DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); break; @@ -346,6 +396,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT8, INT8); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); + DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT); break; @@ -353,6 +404,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONST: return new OpConst(sgt, id); case Op_IDENTITY: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -368,15 +420,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 341d7dc..25dfc6e 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -62,18 +62,55 @@ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \ + if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + { \ + return new OP(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ { \ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ + if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \ + && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + { \ + return new OP(sgt, attribute, id); \ + } \ + +// Statement-expression to evaluate accumulate attribute in-place +#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \ + ({ \ + tosa::DType accumDType = tosa::DType_UNKNOWN; \ + if (auto p = dynamic_cast(attribute)) \ + { \ + auto attr = new tosa::Tosa##ATTRIBUTE_NAME##Attribute(p); \ + ASSERT_MEM(attr); \ + accumDType = tosa::EnumValuesDType()[attr->accum_dtype()]; \ + } \ + else \ + { \ + FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \ + "of this attribute is required in order to determine the accumulate type."); \ + } \ + accumDType; \ + }) \ + #define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index 18fac44..03ee660 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -158,17 +158,21 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index faf7db9..25174bd 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -226,9 +226,11 @@ int OpScatter::eval() DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 2bc7e04..9511c31 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #include "tosa_generated.h" #include +#include "half.hpp" using namespace tosa; @@ -69,6 +70,12 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -109,6 +116,28 @@ struct GetEigenType using type = int32_t; }; +/* Get Accumulate Eigen Type: +Same behaviour as GetEigenType for all DTypes except the +single specialised case of DType_FP16. 
*/ +template +struct GetAccEigenType; +template <> +struct GetAccEigenType +{ + using type = half_float::half; +}; +template +struct GetAccEigenType +{ + using type = typename GetEigenType::type; +}; + +template +struct GetHalfEigenType +{ + using type = half_float::half; +}; + // Meta function to get number of bits template struct GetNumBits @@ -155,6 +184,11 @@ struct GetNumBits { static constexpr int32_t value = 48; }; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 16; +}; // Meta function to get quantized min/max in compile time template @@ -262,6 +296,11 @@ struct GetAccDType static constexpr DType value = DType_INT48; }; template <> +struct GetAccDType +{ + static constexpr DType value = DType_FP16; +}; +template <> struct GetAccDType { static constexpr DType value = DType_FLOAT; diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 2cd94bb..c617dda 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include "tensor_ops.h" #include "quant_util.h" #include "template_types.h" +#include "half.hpp" using namespace TosaReference; using namespace Eigen; @@ -329,8 +330,8 @@ int OpArgMax::eval() return GraphNode::eval(); } -template -OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, +template +OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_AVG_POOL2D, id_) @@ -341,15 +342,15 @@ OpAvgPool2d::OpAvgPool2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pool); } -template -OpAvgPool2d::~OpAvgPool2d() +template +OpAvgPool2d::~OpAvgPool2d() { if (attribute) delete attribute; } -template -int OpAvgPool2d::checkTensorAttributes() +template +int OpAvgPool2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -385,8 +386,8 @@ int OpAvgPool2d::checkTensorAttributes() // This calculates the number of padding elements used for each location along an axis // Average pooling only divides by the number of elements used, not including padding. 
// This function uses left/right, but is also used for vertical padding with top/bottom -template -ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) +template +ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) { ETensor1 result(out_size); @@ -414,8 +415,8 @@ ETensor1 OpAvgPool2d::calculate_div_map_1d(int in_size, int out_ // assuming input and output tensor have same scales like tflite reference // so no need to scale input and output -template -int OpAvgPool2d::eval() +template +int OpAvgPool2d::eval() { int in_batch = this->in->getShape()[0]; int in_height = this->in->getShape()[1]; @@ -439,11 +440,13 @@ int OpAvgPool2d::eval() int stride_h = this->attribute->stride()[0]; int stride_w = this->attribute->stride()[1]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + DEBUG_INFO(OP, "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " - "stride=[%d,%d], pad=[%d,%d,%d,%d]", + "stride=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h, - kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right); + kernel_w, stride_h, stride_w, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); Eigen::array im2col_input_dims; im2col_input_dims[0] = kernel_h * kernel_w; @@ -509,8 +512,7 @@ int OpAvgPool2d::eval() .contract(div_map_w.reshape(Eigen::array{ 1, out_width }), contract_dims) .reshape(Eigen::array{ 1, out_height, out_width, 1 }) .broadcast(bcast); - - if (Dtype != DType_FLOAT) + if (Dtype != DType_FLOAT && Dtype != DType_FP16) { try { @@ -531,14 +533,15 @@ int OpAvgPool2d::eval() } else { + // Case for float-type resizes this->out->getTensor() = (sum / div_map.template cast()).template cast(); } return GraphNode::eval(); } -template -OpConv2d::OpConv2d(SubgraphTraverser* sgt_, +template +OpConv2d::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) @@ -549,15 +552,15 @@ OpConv2d::OpConv2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template -OpConv2d::~OpConv2d() +template +OpConv2d::~OpConv2d() { if (attribute) delete attribute; } -template -int OpConv2d::checkTensorAttributes() +template +int OpConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -574,12 +577,12 @@ int OpConv2d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpConv2d: Output data type not supported for this configuration of operator"); + "OpConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); std::string msg; if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), @@ -593,8 +596,8 @@ int OpConv2d::checkTensorAttributes() return 0; } -template -int OpConv2d::eval() +template +int OpConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -630,12 +633,14 @@ int OpConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + 
DEBUG_INFO(OP, "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " - "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", + "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right); + pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); // GEMM-conv2d, left matrix is input, right matrix is weight Eigen::array im2col_input_dims; @@ -695,33 +700,33 @@ int OpConv2d::eval() // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC] ETensor2 im2col_weight = - weight_val.shuffle(Eigen::array({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); + weight_val.shuffle(Eigen::array({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C] - ETensor2 bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims); + ETensor2 bias_2d = (this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims)).template cast(); // output matrix is [N * H * W, C] - ETensor2 contracted_result = - im2col_input.template cast().contract(im2col_weight.template cast(), contract_dims); + ETensor2 contracted_result = + (im2col_input.template cast().contract(im2col_weight.template cast(), contract_dims)).template cast(); // adding bias - ETensor2 biased_output = contracted_result + bias_2d.template cast(); + ETensor2 biased_output = contracted_result + bias_2d; // reshape back to [N, H, W, C] this->output->getTensor() = biased_output.reshape(col2im_output_dims); if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpConv3d::OpConv3d(SubgraphTraverser* sgt_, +template +OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) @@ -732,15 +737,15 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template -OpConv3d::~OpConv3d() +template +OpConv3d::~OpConv3d() { if (attribute) delete attribute; } -template -int OpConv3d::checkTensorAttributes() +template +int OpConv3d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -757,12 +762,12 @@ int OpConv3d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpConv3d: Output data type not supported for this configuration of operator"); + "OpConv3d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); std::string msg; if (check_conv_attribute(attribute, 3 /* conv_dimension */, input->getShape(), output->getShape(), @@ -776,8 +781,8 @@ int OpConv3d::checkTensorAttributes() return 0; } -template -int OpConv3d::eval() +template +int OpConv3d::eval() { int in_batch = 
this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -821,13 +826,15 @@ int OpConv3d::eval() int dilation_h = this->attribute->dilation()[1]; int dilation_w = this->attribute->dilation()[2]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + DEBUG_INFO( OP, "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], " - "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]", + "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s", in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels, out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h, - dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right); + dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); Eigen::array, 5> pad; pad[0] = std::make_pair(0, 0); @@ -860,7 +867,7 @@ int OpConv3d::eval() this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast); // 2. direct convolution - AccEigenType acc = 0; + AccEigenType acc(0.0); int d_idx, h_idx, w_idx; for (int ob = 0; ob < out_batch; ob++) @@ -874,7 +881,7 @@ int OpConv3d::eval() for (int oc = 0; oc < out_channels; oc++) { // Initialize accumulator with bias value - acc = this->output->getTensor()(ob, od, oh, ow, oc); + acc = (AccEigenType)this->output->getTensor()(ob, od, oh, ow, oc); for (int fd = 0; fd < f_depth; fd++) { d_idx = od * stride_d + fd * dilation_d; @@ -892,7 +899,7 @@ int OpConv3d::eval() } } } - this->output->getTensor()(ob, od, oh, ow, oc) = acc; + this->output->getTensor()(ob, od, oh, ow, oc) = (OutEigenType)acc; } } } @@ -901,15 +908,15 @@ int OpConv3d::eval() if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, +template +OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) @@ -920,15 +927,15 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sg INIT_ATTRIBUTE(Conv); } -template -OpDepthwiseConv2d::~OpDepthwiseConv2d() +template +OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template -int OpDepthwiseConv2d::checkTensorAttributes() +template +int OpDepthwiseConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -945,12 +952,12 @@ int OpDepthwiseConv2d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); + "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); std::string msg; if (check_conv_attribute(attribute, 2 /* conv_dimension */, input->getShape(), output->getShape(), @@ -964,8 +971,8 @@ int 
OpDepthwiseConv2d::checkTensorAttributes() return 0; } -template -int OpDepthwiseConv2d::eval() +template +int OpDepthwiseConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1002,12 +1009,14 @@ int OpDepthwiseConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + DEBUG_INFO(OP, "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right); + pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); Eigen::array, 4> pad; pad[0] = std::make_pair(0, 0); @@ -1061,9 +1070,10 @@ int OpDepthwiseConv2d::eval() { for (int fw = 0; fw < f_width; fw++) { + // Perform multiplication in AccEigenType then cast to OutEigenType this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) += - ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * - (AccEigenType)weight_val(fh, fw, ic, cm)); + (OutEigenType)((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * + (AccEigenType)weight_val(fh, fw, ic, cm)); } } } @@ -1074,15 +1084,15 @@ int OpDepthwiseConv2d::eval() if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, +template +OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) @@ -1093,15 +1103,15 @@ OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_ INIT_ATTRIBUTE(FullyConnected); } -template -OpFullyConnected::~OpFullyConnected() +template +OpFullyConnected::~OpFullyConnected() { if (attribute) delete attribute; } -template -int OpFullyConnected::checkTensorAttributes() +template +int OpFullyConnected::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1128,9 +1138,9 @@ int OpFullyConnected::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpFullyConnected: Output data type not supported for this configuration of operator"); + "OpFullyConnected: Output data type not supported for this configuration of operator"); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); @@ -1138,8 +1148,8 @@ int OpFullyConnected::checkTensorAttributes() return 0; } -template -int OpFullyConnected::eval() +template +int OpFullyConnected::eval() { typedef Eigen::Tensor::DimensionPair DimPair; 
Eigen::array dims{ { DimPair(1, 0) } }; @@ -1163,19 +1173,19 @@ int OpFullyConnected::eval() } this->output->getTensor() = - input_val.template cast().contract(weight_val.template cast(), dims) + - this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); + input_val.template cast().contract(weight_val.template cast(), dims).template cast() + + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } -template -OpMatMul::OpMatMul(SubgraphTraverser* sgt_, +template +OpMatMul::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) @@ -1186,15 +1196,15 @@ OpMatMul::OpMatMul(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(MatMul); } -template -OpMatMul::~OpMatMul() +template +OpMatMul::~OpMatMul() { if (attribute) delete attribute; } -template -int OpMatMul::checkTensorAttributes() +template +int OpMatMul::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1205,11 +1215,11 @@ int OpMatMul::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpMatMul: Output data type not supported for this configuration of operator"); + "OpMatMul: Output data type not supported for this configuration of operator"); a = dynamic_cast*>(inputs[0]); b = dynamic_cast*>(inputs[1]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); ASSERT_MEM(a && b && output); @@ -1255,8 +1265,8 @@ int OpMatMul::checkTensorAttributes() return 0; } -template -int OpMatMul::eval() +template +int OpMatMul::eval() { typedef Eigen::Tensor::DimensionPair DimPair; Eigen::array dims{ { DimPair(1, 0) } }; @@ -1289,22 +1299,22 @@ int OpMatMul::eval() TInRank2 b_rank2_val = b_val.slice(b_begin_array, b_size_array).reshape(b_rank2_shape); TAccRank2 output_rank2_val = a_rank2_val.template cast().contract(b_rank2_val.template cast(), dims); - TAcc output_rank3_val = output_rank2_val.reshape(output_rank3_shape); + TOut output_rank3_val = output_rank2_val.reshape(output_rank3_shape).template cast(); if (i == 0) { this->output->getTensor() = output_rank3_val; } else { - TAcc temp = this->output->getTensor().concatenate(output_rank3_val, 0); + TOut temp = this->output->getTensor().concatenate(output_rank3_val, 0); this->output->getTensor() = temp; } } if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); @@ -1442,8 +1452,8 @@ int OpMaxPool2d::eval() return GraphNode::eval(); } -template -OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, +template +OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) @@ -1454,15 +1464,15 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sg INIT_ATTRIBUTE(TransposeConv); } -template 
-OpTransposeConv2d::~OpTransposeConv2d() +template +OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; } -template -int OpTransposeConv2d::checkTensorAttributes() +template +int OpTransposeConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1473,12 +1483,12 @@ int OpTransposeConv2d::checkTensorAttributes() } ERROR_IF(outputs[0]->getDtype() != AccDtype, - "OpTransposeConv2d: Output data type not supported for this configuration of operator"); + "OpTransposeConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); weight = dynamic_cast*>(inputs[1]); bias = dynamic_cast*>(inputs[2]); - output = dynamic_cast*>(outputs[0]); + output = dynamic_cast*>(outputs[0]); if (attribute->out_pad().size() != 4) { @@ -1556,8 +1566,8 @@ int OpTransposeConv2d::checkTensorAttributes() return 0; } -template -int OpTransposeConv2d::eval() +template +int OpTransposeConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1584,6 +1594,8 @@ int OpTransposeConv2d::eval() int stride_h = this->attribute->stride()[0]; int stride_w = this->attribute->stride()[1]; + tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels, in_channels); @@ -1594,10 +1606,10 @@ int OpTransposeConv2d::eval() DEBUG_INFO(OP, "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s", in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top, - out_pad_bottom, out_pad_left, out_pad_right); + out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]); TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); @@ -1645,8 +1657,8 @@ int OpTransposeConv2d::eval() if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height)) { this->output->getTensor()(ob, out_y, out_x, oc) += - ((AccEigenType)input_val(ob, ih, iw, ic) * - (AccEigenType)weight_val(oc, fh, fw, ic)); + (OutEigenType) ((AccEigenType)input_val(ob, ih, iw, ic) * + (AccEigenType)weight_val(oc, fh, fw, ic)); } } } @@ -1658,51 +1670,68 @@ int OpTransposeConv2d::eval() if (AccDtype == DType_INT48) { - this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); - this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); } return GraphNode::eval(); } // template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); -DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT) -DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT8) -DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16) - -DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT); 
-DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT4); -DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT8, INT8); -DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8); - -DEF_INSTANTIATE_TWO_TYPE(OpConv3d, FLOAT, FLOAT); -DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT4); -DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT8, INT8); -DEF_INSTANTIATE_TWO_TYPE(OpConv3d, INT16, INT8); - -DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); -DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4); -DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8); -DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); - -DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); -DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT4); -DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT8, INT8); -DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8); - -DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT8); -DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16); -DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT); - +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); + + // [in_t, weight_t, acc_t] +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); + +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48); + +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT); + +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); -DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); 
-DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT4); -DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT8, INT8); -DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 24eadeb..fd6dd25 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate* output; }; -template +template class OpAvgPool2d : public GraphNode { public: @@ -55,9 +55,8 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; @@ -75,7 +74,7 @@ protected: ETensor1 calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template +template class OpConv2d : public GraphNode { public: @@ -85,15 +84,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -102,11 +100,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpConv3d : public GraphNode { public: @@ -116,15 +114,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = 
Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -133,11 +130,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpDepthwiseConv2d : public GraphNode { public: @@ -147,15 +144,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -164,11 +160,11 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaConvAttribute* attribute; }; -template +template class OpFullyConnected : public GraphNode { public: @@ -178,14 +174,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -194,12 +190,12 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; tosa::TosaFullyConnectedAttribute* attribute; }; -template +template class OpMatMul : public GraphNode { public: @@ -209,11 +205,11 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; using InEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TOut = Eigen::Tensor; using TInRank2 = Eigen::Tensor; using TAccRank2 = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; @@ -222,7 +218,7 @@ public: protected: TosaReference::TensorTemplate* a; TosaReference::TensorTemplate* b; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; int64_t N; int64_t H; int64_t W; @@ -252,7 +248,7 @@ protected: tosa::TosaPoolAttribute* attribute; }; -template +template class OpTransposeConv2d : 
public GraphNode { public: @@ -262,15 +258,14 @@ public: virtual int checkTensorAttributes() final; virtual int eval() final; - static constexpr DType AccDtype = GetAccDType::value; - using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; - using TBias = Eigen::Tensor; - using TAcc = Eigen::Tensor; + using TBias = Eigen::Tensor; + using TOut = Eigen::Tensor; static constexpr int64_t AccQMin = GetQMin::value; static constexpr int64_t AccQMax = GetQMax::value; @@ -279,7 +274,7 @@ protected: TosaReference::TensorTemplate* input; TosaReference::TensorTemplate* weight; TosaReference::TensorTemplate* bias; - TosaReference::TensorTemplate* output; + TosaReference::TensorTemplate* output; TosaTransposeConvAttribute* attribute; }; diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 52de2e4..50e710a 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #include "quant_util.h" #include "template_types.h" #include +#include "half.hpp" using namespace TosaReference; using namespace Eigen; @@ -286,6 +287,30 @@ CastHelper::CastHelper() }; } +template +CastHelper::CastHelper() +{ + fcn = [](InEigenType in) -> float { + half_float::half out = half_float::half_cast(in); // Cast to half_float + return half_float::half_cast(out); // Cast to float (underlying FP16 EigenType) + }; +} + +template +CastHelper::CastHelper() +{ + // Assuming InEigenType = float. 
+ fcn = [](float in) -> OutEigenType { + // Perform initial rounding in half-precision then cast back to float + half_float::half h = half_float::half_cast(in); + h = std::round(h); + OutEigenType out = half_float::half_cast(h); + out = std::max(out, OutMin); + out = std::min(out, OutMax); + return out; + }; +} + template CastHelper::CastHelper() { @@ -313,15 +338,21 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 53470d1..5f197cf 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -100,6 +100,42 @@ private: FcnType fcn; }; +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template class CastHelper { diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index d0cc6cf..386f0e5 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -385,6 +385,14 @@ int SubgraphTraverser::allocateTensor() tensor->setTensorValueInt64(i64_data.size(), i64_data.data()); } break; + case DType_FP16: + { + // Interpret f16 data as float + std::vector f16_data; + TosaSerializationHandler::ConvertU8toF16(ts->GetData(), tensor->getElementCount(), f16_data); + tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); + } + break; case DType_FLOAT: { std::vector fp32_data; @@ -693,7 +701,7 @@ int SubgraphTraverser::validateGraph() DType dtype = currTensor->getDtype(); // Float-point disallowed - if (dtype == DType_FLOAT) + if (dtype == DType_FLOAT || dtype == DType_FP16) { WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point " "disabled, but %s tensor %s found\n", diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 7cbeb13..cbe12a9 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ #include "tensor.h" #include "arith_util.h" +#include "half.hpp" using namespace TosaReference; using namespace Eigen; @@ -84,6 +85,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) { uint32_t elements = getElementCount(); float* fdatabuf = nullptr; + half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; @@ -97,6 +99,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf); break; + case DType_FP16: + f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); + ASSERT_MEM(f16databuf); + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, f16databuf); + break; case DType_INT32: case DType_UINT8: case DType_INT4: @@ -146,9 +156,17 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) switch (getDtype()) { + case DType_FP16: + // Convert from fp16 to fp32 + for (uint32_t i=0; i < elements; i++) { + fdatabuf[i] = half_float::half_cast(f16databuf[i]); + } + // Fall through to DType_FLOAT case case DType_FLOAT: if (setTensorValueFloat(elements, fdatabuf)) { + if (f16databuf) + free(f16databuf); free(fdatabuf); return 1; } @@ -187,6 +205,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) if (fdatabuf) free(fdatabuf); + if (f16databuf) + free(f16databuf); if (i32databuf) free(i32databuf); if (i64databuf) @@ -200,11 +220,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int TosaReference::Tensor::writeToNpyFile(const char* filename) const { float* fdatabuf = nullptr; + half_float::half* f16databuf = nullptr; int32_t* i32databuf = nullptr; int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; - int elements = getElementCount(); + uint32_t elements = getElementCount(); switch (getDtype()) { @@ -222,6 +243,27 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(fdatabuf); break; + case DType_FP16: + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); + ASSERT_MEM(f16databuf); + + if (getTensorValueFloat(elements, 
fdatabuf)) + { + free(fdatabuf); + free(f16databuf); + return 1; + } + // Convert fp32 to fp16 + for (uint32_t i=0; i < elements; i++) { + f16databuf[i] = half_float::half_cast(fdatabuf[i]); + } + nperror = NumpyUtilities::writeToNpyFile(filename, shape, f16databuf); + + free(fdatabuf); + free(f16databuf); + break; case DType_INT32: case DType_UINT8: case DType_INT4: diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 6b7d5f1..78a210e 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -643,6 +643,7 @@ public: switch (tensorDtype_) { case DType_FLOAT: + case DType_FP16: switch (rank) { case 0: diff --git a/scripts/json2numpy/json2numpy.py b/scripts/json2numpy/json2numpy.py index 21b1acd..c04013e 100644 --- a/scripts/json2numpy/json2numpy.py +++ b/scripts/json2numpy/json2numpy.py @@ -16,8 +16,10 @@ class NumpyArrayEncoder(json.JSONEncoder): """Encode default operation.""" if isinstance(obj, np.integer): return int(obj) - elif isinstance(obj, np.floating): + elif isinstance(obj, np.float32): return float(obj) + elif isinstance(obj, np.float16): + return np.float16(obj) elif isinstance(obj, np.ndarray): return obj.tolist() return super(NumpyArrayEncoder, self).default(obj) diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index c92710d..485a11d 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit c92710d7259558fb0cd9e9b38d0c78da21c6e2d4 +Subproject commit 485a11d8cb67c8062c632f0987cd31cedbe93d6d diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 66864c2..8ae3218 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -147,14 +147,14 @@ def test_check( tolerance = 0.0 # Fall-through to below to add failure values - elif reference_result.dtype == np.float32: + # TODO: update for fp16 tolerance + elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16: tolerance = float_tolerance if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True): print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, tolerance, "") msg = "Float result does not match within tolerance of {}".format(tolerance) # Fall-through to below to add failure values - else: print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) msg = "Unsupported results type: {}".format(reference_result.dtype) diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index a65e220..69968d3 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import itertools import math +import warnings import numpy as np from generator.tosa_error_if import ErrorIf from generator.tosa_error_if import TosaErrorIfArgGen +from generator.tosa_utils import get_accum_dtype_from_tgTypes +from generator.tosa_utils import get_wrong_output_type from generator.tosa_utils import MAX_RESIZE_DIMENSION from serializer.tosa_serializer import DTypeNames from tosa.DType import DType @@ -773,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] == DType.FLOAT: + if 
dtypeList[0] in (DType.FP16, DType.FLOAT): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -982,7 +985,7 @@ class TosaArgGen: return axes @staticmethod - def agConv(testGen, opName, shapeList, dtype, error_name=None): + def agConv(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] @@ -990,6 +993,8 @@ class TosaArgGen: # determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3]) k = [int(x) for x in opName.split("_")[-1].split("x")] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Check the rank rank = 5 if opName.startswith("conv3d") else 4 if error_name != ErrorIf.WrongRank: @@ -1089,12 +1094,13 @@ class TosaArgGen: ): arg_list.append( ( - "st{}_pad{}_dilat{}".format( + "acc{}_st{}_pad{}_dilat{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "".join([str(x) for x in d]), ), - [s, p, d], + [accum_dtype, s, p, d], ) ) n += 1 @@ -1102,12 +1108,55 @@ class TosaArgGen: return arg_list @staticmethod - def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None): + def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None): + + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + input_dtype = dtypes[0] + else: + input_dtype = dtypes + + if error_name == ErrorIf.WrongOutputType: + accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype) + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtype = DType.INT32 + else: + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + + return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])] + + @staticmethod + def agMatMul(testGen, opName, shapeList, dtype, error_name=None): + # Get valid accumulate type(s) + if dtype == DType.INT8: + accum_dtypes = [DType.INT32] + elif dtype == DType.INT16: + accum_dtypes = [DType.INT48] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" + + if error_name == ErrorIf.WrongOutputType: + # Get incorrect output dtype for ErrorIf case + accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)] + elif error_name == ErrorIf.WrongInputType: + # Pick some potentially correct output dtype if input type is incorrect + accum_dtypes = [DType.INT32] + + return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes] + + @staticmethod + def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None): arg_list = [] ifm_shape = shapeList[0] filter_shape = shapeList[1] + accum_dtype = get_accum_dtype_from_tgTypes(dtypes) + # Must be rank 4 if error_name != ErrorIf.WrongRank: assert len(ifm_shape) == 4 @@ -1169,12 +1218,13 @@ class TosaArgGen: os = [ifm_shape[0], oh, ow, filter_shape[0]] arg_list.append( ( - "st{}_pad{}_os{}".format( + "acc{}_st{}_pad{}_os{}".format( + testGen.typeStr(accum_dtype), "".join([str(x) for x in s]), "".join([str(x) for x in p]), "x".join([str(x) for x in os]), ), - [s, p, os], + [accum_dtype, s, p, os], ) ) n += 1 @@ -1199,18 +1249,38 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype == DType.FLOAT: + elif dtype in (DType.FP16, DType.FLOAT): pad_const_int = 0 pad_const_fp = testGen.getRandNumberDType(dtype) else: 
return [] for paddings in shape_pad_values: - name = "pad" - for r in range(rank): - before, after = paddings[r] - name = f"{name}{before}{after}" - arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp])) + paddings = list(paddings) + args_valid = True + + if error_name == ErrorIf.PadSmallerZero: + # Prevent negative output shapes while ensuring still testing for negative padding + for i in range(rank): + dim_after_padding = ( + paddings[i][0] + paddings[i][1] + shapeList[0][i] + ) + if dim_after_padding < 1: + paddings[i] = (0, 0) + if all([p > -1 for p in paddings[i]]): + args_valid = False + + if args_valid: + name = "pad" + for r in range(rank): + before, after = paddings[r] + name = f"{name}{before}{after}" + arg_list.append( + (name, [np.array(paddings), pad_const_int, pad_const_fp]) + ) + + if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: + warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}") return arg_list @@ -1232,6 +1302,21 @@ class TosaArgGen: k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)] kernels = {x for x in itertools.product(*([k_vals] * 2))} + if opName == "max_pool2d": + accum_dtypes = [None] # max_pool has no accumulate dtype + elif dtype == DType.INT8 or dtype == DType.INT16: + accum_dtypes = [DType.INT32] + elif dtype == DType.FP16: + accum_dtypes = [DType.FP16, DType.FLOAT] + elif dtype == DType.FLOAT: + accum_dtypes = [DType.FLOAT] + elif error_name is None: + assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" + else: + # Set to something for the ErrorIf case which has + # incorrect input data-type + accum_dtypes = [DType.INT32] + if testGen.args.oversize: # add some oversize argument values bigStride = 7 @@ -1252,63 +1337,70 @@ class TosaArgGen: sparsity_factor = 2 if error_name else 500 sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1 + arg_str = ( + "acc{}_st{}_kern{}_pad{}" + if accum_dtypes[0] is not None + else "st{}_kern{}_pad{}" + ) + + def get_arg_list_element(accum, stride, pad, kern): + # Return tuple containing the formatted argument string and + # the corresponding argument values + arg_str_elems = [ + "".join([str(x) for x in stride]), + "".join([str(x) for x in kern]), + "".join([str(x) for x in pad]), + ] + # Note: different order to string + arg_val_elems = [stride, pad, kern] + + if accum is not None: + arg_str_elems.insert(0, testGen.typeStr(accum)) + arg_val_elems.insert(0, accum) + return (arg_str.format(*arg_str_elems), arg_val_elems) + n = 0 - for s in sorted(list(strides)): - for p in sorted(list(paddings)): - for k in sorted(list(kernels)): - if error_name in [ - ErrorIf.StrideSmallerOne, - ErrorIf.KernelSmallerOne, - ErrorIf.PadSmallerZero, - ErrorIf.PadLargerEqualKernel, - ]: - sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( - testGen, error_name, s, p, k - ) - if None not in [sNew, pNew, kNew] and n % sparsity == 0: - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in sNew]), - "".join([str(x) for x in kNew]), - "".join([str(x) for x in pNew]), - ), - [sNew, pNew, kNew], - ) + for a in accum_dtypes: + for s in sorted(list(strides)): + for p in sorted(list(paddings)): + for k in sorted(list(kernels)): + if error_name in [ + ErrorIf.StrideSmallerOne, + ErrorIf.KernelSmallerOne, + ErrorIf.PadSmallerZero, + ErrorIf.PadLargerEqualKernel, + ]: + sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf( + testGen, error_name, s, p, k ) - elif ( - n % sparsity == 0 - # padding must not exceed the 
kernel size - and p[0] < k[0] - and p[1] < k[0] - and p[2] < k[1] - and p[3] < k[1] - # the padded shape must exceed the kernel size - and (shape[1] + p[0] + p[1]) > k[0] - and (shape[2] + p[2] + p[3]) > k[1] - ): - remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] - remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] - if ( - # the parameters must produce integer exact output - error_name != ErrorIf.PoolingOutputShapeNonInteger - and remainder_h == 0 - and remainder_w == 0 - ) or ( - error_name == ErrorIf.PoolingOutputShapeNonInteger - and (remainder_h != 0 or remainder_w != 0) + if None not in [sNew, pNew, kNew] and n % sparsity == 0: + arg_vals = [a, sNew, pNew, kNew] + arg_list.append(get_arg_list_element(*arg_vals)) + elif ( + n % sparsity == 0 + # padding must not exceed the kernel size + and p[0] < k[0] + and p[1] < k[0] + and p[2] < k[1] + and p[3] < k[1] + # the padded shape must exceed the kernel size + and (shape[1] + p[0] + p[1]) > k[0] + and (shape[2] + p[2] + p[3]) > k[1] ): - arg_list.append( - ( - "st{}_kern{}_pad{}".format( - "".join([str(x) for x in s]), - "".join([str(x) for x in k]), - "".join([str(x) for x in p]), - ), - [s, p, k], - ) - ) - n += 1 + remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0] + remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1] + if ( + # the parameters must produce integer exact output + error_name != ErrorIf.PoolingOutputShapeNonInteger + and remainder_h == 0 + and remainder_w == 0 + ) or ( + error_name == ErrorIf.PoolingOutputShapeNonInteger + and (remainder_h != 0 or remainder_w != 0) + ): + arg_vals = [a, s, p, k] + arg_list.append(get_arg_list_element(*arg_vals)) + n += 1 return arg_list @@ -1327,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.FP16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FLOAT: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1734,6 +1828,8 @@ class TosaArgGen: outputDTypeList = [DType.INT32] elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: outputDTypeList = [DType.INT48] + elif dtype == DType.FP16: + outputDTypeList = [DType.FP16] elif dtype == DType.FLOAT: outputDTypeList = [DType.FLOAT] elif error_name == ErrorIf.WrongInputType: diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index f9a00f9..a766803 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -120,6 +120,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.NEAREST and dtype == DType.INT16: incorrect_types = ( @@ -128,6 +129,7 @@ class TosaErrorIfArgGen: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT8: incorrect_types = ( @@ -136,6 +138,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT48, DType.FLOAT, + DType.FP16, ) elif mode == ResizeMode.BILINEAR and dtype == DType.INT16: incorrect_types = ( @@ -144,6 +147,16 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.FLOAT, + DType.FP16, + ) + elif dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FLOAT, ) elif dtype == DType.FLOAT: incorrect_types = ( @@ -152,6 +165,7 @@ class TosaErrorIfArgGen: DType.INT16, DType.INT32, DType.INT48, + DType.FP16, ) outputDType = testGen.rng.choice(a=incorrect_types) 
@@ -285,8 +299,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FLOAT]: - outputDType = [DType.BOOL, DType.INT48, DType.FLOAT] + if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -400,6 +414,7 @@ class TosaErrorValidator: and input_dtype == DType.INT16 and output_dtype != DType.INT48 ) + or (input_dtype == DType.FP16 and output_dtype != DType.FP16) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True @@ -413,19 +428,28 @@ class TosaErrorValidator: if ( (input_dtype == DType.INT8 and output_dtype != DType.INT32) or (input_dtype == DType.INT16 and output_dtype != DType.INT48) + or ( + input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) + ) or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] and output_dtype != DType.INT32 ): error_result = True elif op["op"] == Op.MUL: - if input_dtype != DType.FLOAT and output_dtype != DType.INT32: + if ( + input_dtype not in (DType.FP16, DType.FLOAT) + and output_dtype != DType.INT32 + ): + error_result = True + elif input_dtype == DType.FP16 and output_dtype != DType.FP16: error_result = True elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT: error_result = True @@ -449,17 +473,39 @@ class TosaErrorValidator: or ( input_dtype == DType.INT8 and output_dtype - not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT16 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT32, + DType.FLOAT, + DType.FP16, + ] ) or ( input_dtype == DType.INT32 and output_dtype - not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT] + not in [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.FLOAT, + DType.FP16, + ] + ) + or ( + input_dtype == DType.FP16 + and output_dtype not in [DType.INT8, DType.INT16, DType.INT32] ) or ( input_dtype == DType.FLOAT @@ -479,6 +525,8 @@ class TosaErrorValidator: and output_dtype != DType.INT32 or input_dtype == DType.INT16 and output_dtype != DType.INT48 + or input_dtype == DType.FP16 + and output_dtype not in (DType.FP16, DType.FLOAT) or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT ): @@ -2257,12 +2305,13 @@ class TosaInvalidValidator: return ( not (input_dtype == DType.INT8 and output_dtype == DType.INT32) and not (input_dtype == DType.INT16 and output_dtype == DType.INT48) + and not (input_dtype == DType.FP16 and output_dtype == DType.FP16) and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) ) elif mode == ResizeMode.NEAREST: # Invalid output data type / Invalid input datatype return (input_dtype != output_dtype) or ( - input_dtype not in [DType.INT8, DType.INT16, DType.FLOAT] + input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] ) else: # Invalid resize mode @@ -2276,8 +2325,11 @@ class TosaInvalidValidator: input_shape = inputShapes[0] args = kwargs["args"] - strides = args[0] - padding = args[1] + + # MaxPool2D has no accum_dtype arg + stride_idx, pad_idx = (0, 1) if opName == "max_pool2d" 
else (1, 2) + strides = args[stride_idx] + padding = args[pad_idx] if opName.endswith("pool2d"): # avg_pool2d, max_pool2d @@ -2365,7 +2417,7 @@ class TosaInvalidValidator: @staticmethod def ivNonPositiveOutputShape(**kwargs): args = kwargs["args"] - output_shape = args[2] + output_shape = args[3] if output_shape[1] <= 0 or output_shape[2] <= 0: # Negative output shape return True diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b76b656..9ff6ec5 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -81,6 +81,8 @@ class TosaTestGen: return np.int64( self.rng.integers(low=-(1 << 47), high=(1 << 47), size=shape) ) + elif dtype == DType.FP16: + return np.float16(self.rng.random(size=shape)) elif dtype == DType.FLOAT: return np.float32(self.rng.random(size=shape)) else: @@ -128,6 +130,9 @@ class TosaTestGen: def getRandNumberDType(self, dtype): if dtype == DType.FLOAT: return self.rng.random() + elif dtype == DType.FP16: + rand_f32 = self.rng.random() + return np.float16(rand_f32) elif dtype == DType.BOOL: return self.rng.choice([False, True]) # TOSA specific INT4 weight range from -7 to 7 @@ -178,13 +183,15 @@ class TosaTestGen: return "i32" elif t == DType.INT48: return "i48" + elif t == DType.FP16: + return "f16" elif t == DType.FLOAT: return "float" else: raise Exception("Unknown dtype, cannot convert to string: {}".format(t)) def typeWidth(self, t): - """Get the datatype width for integer types""" + """Get the datatype width for data types""" if t == DType.INT4: return 4 elif t == DType.INT8: @@ -199,6 +206,8 @@ class TosaTestGen: return 32 elif t == DType.INT48: return 48 + elif t == DType.FP16: + return 16 elif t == DType.FLOAT: return 32 elif t == DType.BOOL: @@ -346,7 +355,7 @@ class TosaTestGen: # Special for multiply: # Force the result to INT32 for INT types - if a.dtype != DType.FLOAT: + if a.dtype not in (DType.FP16, DType.FLOAT): result_tens.setDtype(DType.INT32) if error_name == ErrorIf.WrongOutputType: all_dtypes = [DType.INT8, DType.INT16, DType.INT48] @@ -533,6 +542,7 @@ class TosaTestGen: self, op, input, + accum_dtype, stride, pad, kernel, @@ -585,17 +595,43 @@ class TosaTestGen: qinfo = [0, 0] attr = ts.TosaSerializerAttribute() - attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1]) + attr.PoolAttribute(kernel, stride, pad, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens + def build_maxpool2d( + self, + op, + input, + stride, + pad, + kernel, + validator_fcns=None, + error_name=None, + qinfo=None, + ): + # Same as build_pool2d but manually sets accum_dtype value + # (maxpool has no accum_dtype) + return self.build_pool2d( + op, + input, + DType.UNKNOWN, + stride, + pad, + kernel, + validator_fcns, + error_name, + qinfo, + ) + def build_conv2d( self, op, ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -605,7 +641,15 @@ class TosaTestGen: ): assert len(padding) == 4 result_tens = OutputShaper.conv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -648,7 +692,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, 
attr) return result_tens @@ -659,6 +703,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -668,7 +713,15 @@ class TosaTestGen: ): assert len(padding) == 6 result_tens = OutputShaper.conv3dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -711,7 +764,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -722,6 +775,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, stride, out_pad, output_shape, @@ -731,7 +785,7 @@ class TosaTestGen: ): assert len(out_pad) == 4 result_tens = OutputShaper.transposeConv2DOp( - self.ser, self.rng, ifm, output_shape, error_name + self.ser, self.rng, ifm, output_shape, accum_dtype, error_name ) # Ensure new output type has correct qinfo @@ -773,7 +827,9 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) + attr.TransposeConvAttribute( + out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype + ) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -784,6 +840,7 @@ class TosaTestGen: ifm, filter, bias, + accum_dtype, strides, padding, dilations, @@ -792,7 +849,15 @@ class TosaTestGen: qinfo=None, ): result_tens = OutputShaper.depthwiseConv2dOp( - self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name + self.ser, + self.rng, + ifm, + filter, + accum_dtype, + strides, + padding, + dilations, + error_name, ) # Ensure new output type has correct qinfo @@ -835,16 +900,24 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens def build_fully_connected( - self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None + self, + op, + ifm, + filter, + bias, + accum_dtype, + validator_fcns=None, + error_name=None, + qinfo=None, ): result_tens = OutputShaper.fullyConnectedOp( - self.ser, self.rng, ifm, filter, error_name + self.ser, self.rng, ifm, filter, accum_dtype, error_name ) # Invalidate Input/Output list for error if checks. @@ -871,17 +944,22 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) + attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens - def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None): - result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name) + def build_matmul( + self, op, a, b, accum_dtype, validator_fcns=None, error_name=None, qinfo=None + ): + result_tens = OutputShaper.matmulOp( + self.ser, self.rng, a, b, accum_dtype, error_name + ) # Invalidate Input/Output list for error if checks. 
input_list = [a.name, b.name] @@ -908,11 +986,12 @@ class TosaTestGen: input_list=input_list, output_list=output_list, num_operands=num_operands, + accum_dtype=accum_dtype, ): return None attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(qinfo[0], qinfo[1]) + attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -995,7 +1074,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype == DType.FLOAT: + if a.dtype in (DType.FP16, DType.FLOAT): attr.ClampAttribute(0, 0, min_val, max_val) else: attr.ClampAttribute(min_val, max_val, 0, 0) @@ -1811,7 +1890,7 @@ class TosaTestGen: op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr ) - if a.dtype in (DType.FLOAT, DType.INT32): + if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32): then_op, else_op = Op.ADD, Op.SUB elif a.dtype in (DType.INT8, DType.INT16): then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT @@ -2350,22 +2429,37 @@ class TosaTestGen: # if not specified, defaults to (1, 4) # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum) # 'types': array of datatypes to be tested - TYPE_FP = [DType.FLOAT] + TYPE_FP = [DType.FLOAT, DType.FP16] TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4 - TYPE_INT_FP = [DType.INT8, DType.INT16, DType.INT32, DType.FLOAT] # Excludes INT4 + TYPE_INT_FP = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.FLOAT, + ] # Excludes INT4 TYPE_BOOL = [DType.BOOL] - TYPE_FI32 = [DType.FLOAT, DType.INT32] - TYPE_FIB = [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL] + TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32 + TYPE_FIB = [ + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ] TYPE_FI16 = [DType.FLOAT, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT] + TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT] TYPE_CONV = [ [DType.INT8, DType.INT4, DType.INT32], [DType.INT8, DType.INT8, DType.INT32], [DType.INT16, DType.INT8, DType.INT48], + [DType.FP16, DType.FP16, DType.FP16], + [DType.FP16, DType.FP16, DType.FLOAT], DType.FLOAT, ] @@ -2524,7 +2618,7 @@ class TosaTestGen: build_fully_connected, TosaTensorGen.tgFullyConnected, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agFullyConnected, ), "qgen": TosaQuantGen.qgConv, "types": TYPE_CONV, @@ -2546,7 +2640,7 @@ class TosaTestGen: build_matmul, TosaTensorGen.tgMatmul, TosaTensorValuesGen.tvgDefault, - None, + TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, "types": TYPE_NARROW_INT_FP, @@ -2564,7 +2658,7 @@ class TosaTestGen: "operands": (1, 0), "rank": (4, 4), "build_fcn": ( - build_pool2d, + build_maxpool2d, TosaTensorGen.tgNHWC, TosaTensorValuesGen.tvgDefault, TosaArgGen.agPooling, @@ -3384,7 +3478,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReduceSum, TosaArgGen.agAxis, ), - "types": TYPE_FI32, + "types": (DType.FP16, DType.FLOAT, DType.INT32), "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -3571,7 +3665,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, None, ), - "types": TYPE_INT_FP, + "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3612,7 +3706,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, 
TosaArgGen.agResize, ), - "types": [DType.INT8, DType.INT16, DType.FLOAT], + "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT), "invalid_test_validators": ( TosaInvalidValidator.ivWrongDataTypeOrModeResize, ), @@ -3646,7 +3740,14 @@ class TosaTestGen: TosaTensorValuesGen.tvgDefault, TosaArgGen.agCast, ), - "types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL], + "types": ( + DType.FP16, + DType.FLOAT, + DType.INT8, + DType.INT16, + DType.INT32, + DType.BOOL, + ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -3925,7 +4026,9 @@ class OutputShaper: return ser.addOutput(shape, outputDType) @staticmethod - def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv2dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NHWC # Filter: OHWI @@ -3958,26 +4061,26 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod - def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None): + def conv3dOp( + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None + ): # IFM: NDHWC # Filter: ODHWI @@ -4020,27 +4123,25 @@ class OutputShaper: ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @staticmethod def depthwiseConv2dOp( - ser, rng, ifm, filter, strides, padding, dilations, error_name=None + ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None ): # IFM: NHWC # Filter: HWCM @@ -4073,20 +4174,18 @@ class OutputShaper: ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]] - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input 
type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(ofm_shape, out_dtype) @@ -4119,6 +4218,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -4128,55 +4228,20 @@ class OutputShaper: return ser.addOutput(ofm_shape, outputDType) @staticmethod - def fullyConnectedOp(ser, rng, input, filter, error_name=None): + def fullyConnectedOp(ser, rng, input, filter, accum_dtype, error_name=None): # input: N, IC # filter: OC, IC # output: N, OC output_shape = [input.shape[0], filter.shape[0]] - if error_name == ErrorIf.WrongOutputType: - if input.dtype == DType.INT8: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT48, - DType.FLOAT, - ) - elif input.dtype == DType.INT16: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.FLOAT, - ) - elif input.dtype == DType.FLOAT: - incorrect_types = ( - DType.INT4, - DType.INT8, - DType.INT16, - DType.INT32, - DType.INT48, - ) - out_dtype = rng.choice(a=incorrect_types) - elif input.dtype == DType.INT8: - out_dtype = DType.INT32 - elif input.dtype == DType.INT16: - out_dtype = DType.INT48 - elif input.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: - # Pick some potentially correct output dtype if input type is incorrect - out_dtype = DType.INT32 - else: - raise Exception("Unsupported input dtype: {}".format(input.dtype)) + # Validated in arg_gen (also invalidated for ErrorIf) + out_dtype = accum_dtype return ser.addOutput(output_shape, out_dtype) @staticmethod - def matmulOp(ser, rng, a, b, error_name=None): + def matmulOp(ser, rng, a, b, accum_dtype, error_name=None): # a: N, H, C # b: N, C, W # out: N, H, W @@ -4200,7 +4265,7 @@ class OutputShaper: DType.INT32, DType.FLOAT, ) - elif a.dtype == DType.FLOAT: + elif a.dtype == DType.FLOAT or a.dtype == DType.FP16: incorrect_types = ( DType.INT4, DType.INT8, @@ -4209,17 +4274,11 @@ class OutputShaper: DType.INT48, ) out_dtype = rng.choice(a=incorrect_types) - elif a.dtype == DType.INT8: - out_dtype = DType.INT32 - elif a.dtype == DType.INT16: - out_dtype = DType.INT48 - elif a.dtype == DType.FLOAT: - out_dtype = DType.FLOAT elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype)) + out_dtype = accum_dtype # Validated in arg_gen return ser.addOutput(output_shape, out_dtype) @@ -4269,10 +4328,6 @@ class OutputShaper: bad_dim = rng.choice(range(len(output_shape))) output_shape[bad_dim] -= rng.choice([1, 2]) - # Fix negative output shape if error_if test causes it - if error_name == ErrorIf.PadSmallerZero and min(output_shape) < 1: - output_shape = [i if i >= 1 else 1 for i in output_shape] - if error_name == ErrorIf.WrongOutputType: all_dtypes = [ DType.INT8, @@ -4280,6 +4335,7 @@ class OutputShaper: DType.INT32, DType.INT48, DType.FLOAT, + DType.FP16, ] wrong_dtypes = list(set(all_dtypes) - set([a.dtype])) outputDType = 
rng.choice(wrong_dtypes) @@ -4546,7 +4602,7 @@ class OutputShaper: return ser.addOutput(val.shape, out_dtype) @staticmethod - def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None): + def transposeConv2DOp(ser, rng, ifm, output_shape, accum_dtype, error_name=None): if error_name == ErrorIf.ConvOutputShapeMismatch: choices = [1, 2, 3] change = rng.choice(choices) @@ -4555,20 +4611,18 @@ class OutputShaper: if change in [2, 3]: output_shape[2] = output_shape[2] + rng.choice(choices) - if ifm.dtype == DType.INT8: - out_dtype = DType.INT32 - elif ifm.dtype == DType.INT16: - out_dtype = DType.INT48 - elif ifm.dtype == DType.FLOAT: - out_dtype = DType.FLOAT - elif error_name == ErrorIf.WrongInputType: + if error_name == ErrorIf.WrongInputType: # Pick some potentially correct output dtype if input type is incorrect out_dtype = DType.INT32 else: - raise Exception(f"Unsupported input dtype: {ifm.dtype}") + out_dtype = accum_dtype if error_name == ErrorIf.WrongOutputType: - wrong_dtypes = list(usableDTypes(excludes=[out_dtype])) + if ifm.dtype == DType.FP16: + excludes = [DType.FP16, DType.FLOAT] + else: + excludes = [out_dtype] + wrong_dtypes = list(usableDTypes(excludes=excludes)) out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 6a689d0..7fa31e7 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -84,3 +84,42 @@ def product(shape): for n in shape: value *= n return value + + +def get_accum_dtype_from_tgTypes(dtypes): + # Get accumulate data-type from the test generator's defined types + if isinstance(dtypes, list) or isinstance(dtypes, tuple): + return dtypes[-1] + else: + return dtypes + + +def get_wrong_output_type(op_name, rng, input_dtype): + if op_name == "fully_connected" or op_name == "matmul": + if input_dtype == DType.INT8: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT48, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.INT16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.FLOAT, + DType.FP16, + ) + elif input_dtype == DType.FLOAT or input_dtype == DType.FP16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ) + return rng.choice(a=incorrect_types) diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py index efee23b..d78d158 100644 --- a/verif/tests/test_tosa_result_checker.py +++ b/verif/tests/test_tosa_result_checker.py @@ -40,7 +40,7 @@ def _delete_data_file(file: Path): (np.uint16, trc.TestResult.MISMATCH), (np.uint32, trc.TestResult.MISMATCH), (np.uint64, trc.TestResult.MISMATCH), - (np.float16, trc.TestResult.MISMATCH), + (np.float16, trc.TestResult.PASS), (np.float32, trc.TestResult.PASS), (np.float64, trc.TestResult.MISMATCH), (bool, trc.TestResult.PASS), -- cgit v1.2.1
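The accumulator plumbing above reduces to a small rule on the generator side: when an op's 'types' entry is itself a list or tuple (as the new TYPE_CONV combinations are), its last element names the accumulate dtype that gets threaded into the serializer attributes; a plain entry such as DType.FLOAT accumulates in its own type. A minimal sketch of the new helper in use, assuming the import paths used elsewhere under verif/ (generator.tosa_utils, tosa.DType):

    # Sketch only, not part of the patch: deriving the accumulate dtype from a
    # test generator 'types' entry after this change.
    from tosa.DType import DType                                   # assumed import path
    from generator.tosa_utils import get_accum_dtype_from_tgTypes  # helper added by this patch

    # TYPE_CONV combinations are [input_dtype, weight_dtype, accum_dtype].
    fp16_wide = [DType.FP16, DType.FP16, DType.FLOAT]
    assert get_accum_dtype_from_tgTypes(fp16_wide) == DType.FLOAT  # fp16 data, fp32 accumulate

    # A scalar entry keeps its own type as the accumulator.
    assert get_accum_dtype_from_tgTypes(DType.FLOAT) == DType.FLOAT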