diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-08-22 11:44:04 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-08-22 15:28:03 +0000 |
commit | c8530210c17b391f27ace95523e9590e8166fcd8 (patch) | |
tree | 7ab0ea38b730a194aba2da91440a991a7614be8f | |
parent | 57a896172ff77c13655b1dc5acc9cfb2930e0570 (diff) | |
download | ComputeLibrary-c8530210c17b391f27ace95523e9590e8166fcd8.tar.gz |
COMPMID-2417: Add new QASYMM8_PER_CHANNEL data type
Change-Id: I6825320909a553513b98cf9b262fc90e37a2fa30
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1790
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
-rw-r--r-- | arm_compute/core/Types.h | 35 | ||||
-rw-r--r-- | arm_compute/core/Utils.h | 21 | ||||
-rw-r--r-- | src/core/Utils.cpp | 1 | ||||
-rw-r--r-- | src/runtime/CL/CLTensorAllocator.cpp | 2 | ||||
-rw-r--r-- | tests/Utils.h | 1 | ||||
-rw-r--r-- | utils/TypePrinter.h | 3 | ||||
-rw-r--r-- | utils/Utils.h | 1 |
7 files changed, 46 insertions, 18 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 6df74e7b88..aa07067855 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -73,23 +73,24 @@ enum class Format /** Available data types */ enum class DataType { - UNKNOWN, /**< Unknown data type */ - U8, /**< unsigned 8-bit number */ - S8, /**< signed 8-bit number */ - QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */ - QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */ - QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */ - U16, /**< unsigned 16-bit number */ - S16, /**< signed 16-bit number */ - QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */ - U32, /**< unsigned 32-bit number */ - S32, /**< signed 32-bit number */ - U64, /**< unsigned 64-bit number */ - S64, /**< signed 64-bit number */ - F16, /**< 16-bit floating-point number */ - F32, /**< 32-bit floating-point number */ - F64, /**< 64-bit floating-point number */ - SIZET /**< size_t */ + UNKNOWN, /**< Unknown data type */ + U8, /**< unsigned 8-bit number */ + S8, /**< signed 8-bit number */ + QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */ + QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */ + QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */ + QASYMM8_PER_CHANNEL, /**< quantized, asymmetric per channel fixed-point 8-bit number */ + U16, /**< unsigned 16-bit number */ + S16, /**< signed 16-bit number */ + QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */ + U32, /**< unsigned 32-bit number */ + S32, /**< signed 32-bit number */ + U64, /**< unsigned 64-bit number */ + S64, /**< signed 64-bit number */ + F16, /**< 16-bit floating-point number */ + F32, /**< 32-bit floating-point number */ + F64, /**< 64-bit floating-point number */ + SIZET /**< size_t */ }; /** Available Sampling Policies */ diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index bc461e7ba9..eb4cf05ae9 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -115,6 +115,7 @@ inline size_t data_size_from_type(DataType data_type) case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: return 1; case DataType::U16: case DataType::S16: @@ -1014,6 +1015,7 @@ inline bool is_data_type_quantized(DataType dt) case DataType::QSYMM8: case DataType::QASYMM8: case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: case DataType::QSYMM16: return true; default: @@ -1032,6 +1034,7 @@ inline bool is_data_type_quantized_asymmetric(DataType dt) switch(dt) { case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return true; default: return false; @@ -1057,6 +1060,24 @@ inline bool is_data_type_quantized_symmetric(DataType dt) } } +/** Check if a given data type is of per channel type + * + * @param[in] dt Input data type. + * + * @return True if data type is of per channel type, else false. + */ +inline bool is_data_type_quantized_per_channel(DataType dt) +{ + switch(dt) + { + case DataType::QSYMM8_PER_CHANNEL: + case DataType::QASYMM8_PER_CHANNEL: + return true; + default: + return false; + } +} + /** Create a string with the float in full precision. * * @param val Floating point value diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp index d0bffdf660..0c7eea84e3 100644 --- a/src/core/Utils.cpp +++ b/src/core/Utils.cpp @@ -160,6 +160,7 @@ const std::string &arm_compute::string_from_data_type(DataType dt) { DataType::SIZET, "SIZET" }, { DataType::QSYMM8, "QSYMM8" }, { DataType::QSYMM8_PER_CHANNEL, "QSYMM8_PER_CHANNEL" }, + { DataType::QASYMM8_PER_CHANNEL, "QASYMM8_PER_CHANNEL" }, { DataType::QASYMM8, "QASYMM8" }, { DataType::QSYMM16, "QSYMM16" }, }; diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp index f3f16cd8c0..028a764fc2 100644 --- a/src/runtime/CL/CLTensorAllocator.cpp +++ b/src/runtime/CL/CLTensorAllocator.cpp @@ -139,7 +139,7 @@ void CLTensorAllocator::allocate() } // Allocate and fill the quantization parameter arrays - if(info().data_type() == DataType::QSYMM8_PER_CHANNEL) + if(is_data_type_quantized_per_channel(info().data_type())) { const size_t pad_size = 0; populate_quantization_info(_scale, _offset, info().quantization_info(), pad_size); diff --git a/tests/Utils.h b/tests/Utils.h index a14b30b659..81bc2663de 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -352,6 +352,7 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: *reinterpret_cast<uint8_t *>(ptr) = value; break; case DataType::S8: diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index f51d2368e1..69ffe9e4a6 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -628,6 +628,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type) case DataType::QSYMM8_PER_CHANNEL: os << "QSYMM8_PER_CHANNEL"; break; + case DataType::QASYMM8_PER_CHANNEL: + os << "QASYMM8_PER_CHANNEL"; + break; case DataType::S8: os << "S8"; break; diff --git a/utils/Utils.h b/utils/Utils.h index cc5dfbabc2..8605f4e3e1 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -166,6 +166,7 @@ inline std::string get_typestring(DataType data_type) { case DataType::U8: case DataType::QASYMM8: + case DataType::QASYMM8_PER_CHANNEL: return no_endianness + "u" + support::cpp11::to_string(sizeof(uint8_t)); case DataType::S8: case DataType::QSYMM8: |