diff options
author | Inki Dae <inki.dae@samsung.com> | 2020-09-10 15:33:54 +0900 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2020-09-24 16:03:00 +0000 |
commit | d4619e28a4cde423d5b4086a98c31f97b52a68d7 (patch) | |
tree | 9e6da174b200e8135ed2bbc2b0b8cb761d3e1a4f /include | |
parent | 02036e99c1b2074e5e5f04a2fe443f0c90689683 (diff) | |
download | armnn-d4619e28a4cde423d5b4086a98c31f97b52a68d7.tar.gz |
Add int32 and int64 ArgMax op support
This patch adds int32 and int64 ArgMax op support.
Current ARMNN already has an ArgMax op, but it is unused, and
it doesn't support the int64 output type.
So this patch adds a new type, Signed64, and also adds
ArgMinMax computation function for int64 type support.
By default, the output tensor type of the ArgMax op is int64 in the case of
a TensorFlow Lite model, so this patch makes the proper function - the ArgMax op
for int64 or int32 - be called according to the parsed output_type value.
With this patch, ARMNN supports both types - int64 and int32 - for
ArgMinMax op.
Changelog v1:
- Check if output data type of ArgMinMax op is valid or not.
- Use template function to support int32 and int64 types of ArgMinMax function.
- Keep using Signed32 as default data type of m_Output_Type.
Change-Id: I7a8e7e38dd9e5acc81464571d8b4d51378fc7f14
Signed-off-by: Inki Dae <inki.dae@samsung.com>
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 5 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 1 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 2 |
3 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 241b23d4ed..2834336fb2 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -53,17 +53,20 @@ struct ArgMinMaxDescriptor ArgMinMaxDescriptor() : m_Function(ArgMinMaxFunction::Min) , m_Axis(-1) + , m_Output_Type(armnn::DataType::Signed32) {} bool operator ==(const ArgMinMaxDescriptor &rhs) const { - return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis; + return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis && m_Output_Type == rhs.m_Output_Type; } /// Specify if the function is to find Min or Max. ArgMinMaxFunction m_Function; /// Axis to reduce across the input tensor. int m_Axis; + // Tensor data type and this could be int32 or int64. Default type is int64. + armnn::DataType m_Output_Type; }; /// A ComparisonDescriptor for the ComparisonLayer diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 11d807cd89..4a01549a14 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -41,6 +41,7 @@ enum class DataType QSymmS8 = 7, QAsymmS8 = 8, BFloat16 = 9, + Signed64 = 10, QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8, QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16 diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index a2b3c95752..efc69deb67 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -120,6 +120,7 @@ constexpr unsigned int GetDataTypeSize(DataType dataType) case DataType::Float16: return 2U; case DataType::Float32: case DataType::Signed32: return 4U; + case DataType::Signed64: return 8U; case DataType::QAsymmU8: return 1U; case DataType::QAsymmS8: return 1U; case DataType::QSymmS8: return 1U; @@ -171,6 +172,7 @@ constexpr const char* GetDataTypeName(DataType dataType) { case DataType::Float16: return "Float16"; case DataType::Float32: return "Float32"; + case DataType::Signed64: 
return "Signed64"; case DataType::QAsymmU8: return "QAsymmU8"; case DataType::QAsymmS8: return "QAsymmS8"; case DataType::QSymmS8: return "QSymmS8"; |