From c8530210c17b391f27ace95523e9590e8166fcd8 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 22 Aug 2019 11:44:04 +0100 Subject: COMPMID-2417: Add new QASYMM8_PER_CHANNEL data type Change-Id: I6825320909a553513b98cf9b262fc90e37a2fa30 Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1790 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- arm_compute/core/Types.h | 35 ++++++++++++++++++----------------- arm_compute/core/Utils.h | 21 +++++++++++++++++++++ src/core/Utils.cpp | 1 + src/runtime/CL/CLTensorAllocator.cpp | 2 +- tests/Utils.h | 1 + utils/TypePrinter.h | 3 +++ utils/Utils.h | 1 + 7 files changed, 46 insertions(+), 18 deletions(-) diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 6df74e7b88..aa07067855 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -73,23 +73,24 @@ enum class Format /** Available data types */ enum class DataType { - UNKNOWN, /**< Unknown data type */ - U8, /**< unsigned 8-bit number */ - S8, /**< signed 8-bit number */ - QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */ - QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */ - QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */ - U16, /**< unsigned 16-bit number */ - S16, /**< signed 16-bit number */ - QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */ - U32, /**< unsigned 32-bit number */ - S32, /**< signed 32-bit number */ - U64, /**< unsigned 64-bit number */ - S64, /**< signed 64-bit number */ - F16, /**< 16-bit floating-point number */ - F32, /**< 32-bit floating-point number */ - F64, /**< 64-bit floating-point number */ - SIZET /**< size_t */ + UNKNOWN, /**< Unknown data type */ + U8, /**< unsigned 8-bit number */ + S8, /**< signed 8-bit number */ + QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */ + QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */ + QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */ + QASYMM8_PER_CHANNEL, /**< quantized, asymmetric per channel fixed-point 8-bit number */ + U16, /**< unsigned 16-bit number */ + S16, /**< signed 16-bit number */ + QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */ + U32, /**< unsigned 32-bit number */ + S32, /**< signed 32-bit number */ + U64, /**< unsigned 64-bit number */ + S64, /**< signed 64-bit number */ + F16, /**< 16-bit floating-point number */ + F32, /**< 32-bit floating-point number */ + F64, /**< 64-bit floating-point number */ + SIZET /**< size_t */ }; /** Available Sampling Policies */ diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index bc461e7ba9..eb4cf05ae9 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -115,6 +115,7 @@ inline size_t data_size_from_type(DataType data_type) case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: return 1; case DataType::U16: case DataType::S16: @@ -1014,6 +1015,7 @@ inline bool is_data_type_quantized(DataType dt) case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: case DataType::QSYMM16: return true; default: @@ -1032,6 +1034,7 @@ inline bool is_data_type_quantized_asymmetric(DataType dt) switch(dt) { case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return true; default: return false; @@ -1057,6 +1060,24 @@ inline bool is_data_type_quantized_symmetric(DataType dt) } } +/** Check if a given data type is of per channel type + * + * @param[in] dt Input data type. + * + * @return True if data type is of per channel type, else false. + */ +inline bool is_data_type_quantized_per_channel(DataType dt) +{ + switch(dt) + { + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: + return true; + default: + return false; + } +} + /** Create a string with the float in full precision. * * @param val Floating point value diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index d0bffdf660..0c7eea84e3 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -160,6 +160,7 @@ const std::string &arm_compute::string_from_data_type(DataType dt) { DataType::SIZET, "SIZET" }, { DataType::QSYMM8, "QSYMM8" }, { DataType::QSYMM8_PER_CHANNEL, "QSYMM8_PER_CHANNEL" }, + { DataType::QASYMM8_PER_CHANNEL, "QASYMM8_PER_CHANNEL" }, { DataType::QASYMM8, "QASYMM8" }, { DataType::QSYMM16, "QSYMM16" }, }; diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp index f3f16cd8c0..028a764fc2 100644 --- a/src/runtime/CL/CLTensorAllocator.cpp +++ b/src/runtime/CL/CLTensorAllocator.cpp @@ -139,7 +139,7 @@ void CLTensorAllocator::allocate() } // Allocate and fill the quantization parameter arrays - if(info().data_type() == DataType::QSYMM8_PER_CHANNEL) + if(is_data_type_quantized_per_channel(info().data_type())) { const size_t pad_size = 0; populate_quantization_info(_scale, _offset, info().quantization_info(), pad_size); diff --git a/tests/Utils.h b/tests/Utils.h index a14b30b659..81bc2663de 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -352,6 +352,7 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: *reinterpret_cast(ptr) = value; break; case DataType::S8: diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index f51d2368e1..69ffe9e4a6 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -628,6 +628,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type) case DataType::QSYMM8_PER_CHANNEL: os << "QSYMM8_PER_CHANNEL"; break; + case DataType::QASYMM8_PER_CHANNEL: + os << "QASYMM8_PER_CHANNEL"; + break; case DataType::S8: os << "S8"; break; diff --git a/utils/Utils.h b/utils/Utils.h index cc5dfbabc2..8605f4e3e1 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -166,6 +166,7 @@ inline std::string get_typestring(DataType data_type) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return no_endianness + "u" + support::cpp11::to_string(sizeof(uint8_t)); case DataType::S8: case DataType::QSYMM8: -- cgit v1.2.1