aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/Types.h35
-rw-r--r--arm_compute/core/Utils.h21
-rw-r--r--src/core/Utils.cpp1
-rw-r--r--src/runtime/CL/CLTensorAllocator.cpp2
-rw-r--r--tests/Utils.h1
-rw-r--r--utils/TypePrinter.h3
-rw-r--r--utils/Utils.h1
7 files changed, 46 insertions, 18 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 6df74e7b88..aa07067855 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -73,23 +73,24 @@ enum class Format
/** Available data types */
enum class DataType
{
- UNKNOWN, /**< Unknown data type */
- U8, /**< unsigned 8-bit number */
- S8, /**< signed 8-bit number */
- QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */
- QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */
- QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */
- U16, /**< unsigned 16-bit number */
- S16, /**< signed 16-bit number */
- QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */
- U32, /**< unsigned 32-bit number */
- S32, /**< signed 32-bit number */
- U64, /**< unsigned 64-bit number */
- S64, /**< signed 64-bit number */
- F16, /**< 16-bit floating-point number */
- F32, /**< 32-bit floating-point number */
- F64, /**< 64-bit floating-point number */
- SIZET /**< size_t */
+ UNKNOWN, /**< Unknown data type */
+ U8, /**< unsigned 8-bit number */
+ S8, /**< signed 8-bit number */
+ QSYMM8, /**< quantized, symmetric fixed-point 8-bit number */
+ QASYMM8, /**< quantized, asymmetric fixed-point 8-bit number */
+ QSYMM8_PER_CHANNEL, /**< quantized, symmetric per channel fixed-point 8-bit number */
+ QASYMM8_PER_CHANNEL, /**< quantized, asymmetric per channel fixed-point 8-bit number */
+ U16, /**< unsigned 16-bit number */
+ S16, /**< signed 16-bit number */
+ QSYMM16, /**< quantized, symmetric fixed-point 16-bit number */
+ U32, /**< unsigned 32-bit number */
+ S32, /**< signed 32-bit number */
+ U64, /**< unsigned 64-bit number */
+ S64, /**< signed 64-bit number */
+ F16, /**< 16-bit floating-point number */
+ F32, /**< 32-bit floating-point number */
+ F64, /**< 64-bit floating-point number */
+ SIZET /**< size_t */
};
/** Available Sampling Policies */
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index bc461e7ba9..eb4cf05ae9 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -115,6 +115,7 @@ inline size_t data_size_from_type(DataType data_type)
case DataType::QSYMM8:
case DataType::QASYMM8:
case DataType::QSYMM8_PER_CHANNEL:
+ case DataType::QASYMM8_PER_CHANNEL:
return 1;
case DataType::U16:
case DataType::S16:
@@ -1014,6 +1015,7 @@ inline bool is_data_type_quantized(DataType dt)
case DataType::QSYMM8:
case DataType::QASYMM8:
case DataType::QSYMM8_PER_CHANNEL:
+ case DataType::QASYMM8_PER_CHANNEL:
case DataType::QSYMM16:
return true;
default:
@@ -1032,6 +1034,7 @@ inline bool is_data_type_quantized_asymmetric(DataType dt)
switch(dt)
{
case DataType::QASYMM8:
+ case DataType::QASYMM8_PER_CHANNEL:
return true;
default:
return false;
@@ -1057,6 +1060,24 @@ inline bool is_data_type_quantized_symmetric(DataType dt)
}
}
+/** Check if a given data type is of per channel type
+ *
+ * @param[in] dt Input data type.
+ *
+ * @return True if data type is of per channel type, else false.
+ */
+inline bool is_data_type_quantized_per_channel(DataType dt)
+{
+ switch(dt)
+ {
+ case DataType::QSYMM8_PER_CHANNEL:
+ case DataType::QASYMM8_PER_CHANNEL:
+ return true;
+ default:
+ return false;
+ }
+}
+
/** Create a string with the float in full precision.
*
* @param val Floating point value
diff --git a/src/core/Utils.cpp b/src/core/Utils.cpp
index d0bffdf660..0c7eea84e3 100644
--- a/src/core/Utils.cpp
+++ b/src/core/Utils.cpp
@@ -160,6 +160,7 @@ const std::string &arm_compute::string_from_data_type(DataType dt)
{ DataType::SIZET, "SIZET" },
{ DataType::QSYMM8, "QSYMM8" },
{ DataType::QSYMM8_PER_CHANNEL, "QSYMM8_PER_CHANNEL" },
+ { DataType::QASYMM8_PER_CHANNEL, "QASYMM8_PER_CHANNEL" },
{ DataType::QASYMM8, "QASYMM8" },
{ DataType::QSYMM16, "QSYMM16" },
};
diff --git a/src/runtime/CL/CLTensorAllocator.cpp b/src/runtime/CL/CLTensorAllocator.cpp
index f3f16cd8c0..028a764fc2 100644
--- a/src/runtime/CL/CLTensorAllocator.cpp
+++ b/src/runtime/CL/CLTensorAllocator.cpp
@@ -139,7 +139,7 @@ void CLTensorAllocator::allocate()
}
// Allocate and fill the quantization parameter arrays
- if(info().data_type() == DataType::QSYMM8_PER_CHANNEL)
+ if(is_data_type_quantized_per_channel(info().data_type()))
{
const size_t pad_size = 0;
populate_quantization_info(_scale, _offset, info().quantization_info(), pad_size);
diff --git a/tests/Utils.h b/tests/Utils.h
index a14b30b659..81bc2663de 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -352,6 +352,7 @@ void store_value_with_data_type(void *ptr, T value, DataType data_type)
{
case DataType::U8:
case DataType::QASYMM8:
+ case DataType::QASYMM8_PER_CHANNEL:
*reinterpret_cast<uint8_t *>(ptr) = value;
break;
case DataType::S8:
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index f51d2368e1..69ffe9e4a6 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -628,6 +628,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type)
case DataType::QSYMM8_PER_CHANNEL:
os << "QSYMM8_PER_CHANNEL";
break;
+ case DataType::QASYMM8_PER_CHANNEL:
+ os << "QASYMM8_PER_CHANNEL";
+ break;
case DataType::S8:
os << "S8";
break;
diff --git a/utils/Utils.h b/utils/Utils.h
index cc5dfbabc2..8605f4e3e1 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -166,6 +166,7 @@ inline std::string get_typestring(DataType data_type)
{
case DataType::U8:
case DataType::QASYMM8:
+ case DataType::QASYMM8_PER_CHANNEL:
return no_endianness + "u" + support::cpp11::to_string(sizeof(uint8_t));
case DataType::S8:
case DataType::QSYMM8: