diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 5 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 1 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 2 |
3 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 241b23d4ed..2834336fb2 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -53,17 +53,20 @@ struct ArgMinMaxDescriptor ArgMinMaxDescriptor() : m_Function(ArgMinMaxFunction::Min) , m_Axis(-1) + , m_Output_Type(armnn::DataType::Signed32) {} bool operator ==(const ArgMinMaxDescriptor &rhs) const { - return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis; + return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis && m_Output_Type == rhs.m_Output_Type; } /// Specify if the function is to find Min or Max. ArgMinMaxFunction m_Function; /// Axis to reduce across the input tensor. int m_Axis; + // Tensor data type and this could be int32 or int64. Default type is int64. + armnn::DataType m_Output_Type; }; /// A ComparisonDescriptor for the ComparisonLayer diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 11d807cd89..4a01549a14 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -41,6 +41,7 @@ enum class DataType QSymmS8 = 7, QAsymmS8 = 8, BFloat16 = 9, + Signed64 = 10, QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8, QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16 diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index a2b3c95752..efc69deb67 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -120,6 +120,7 @@ constexpr unsigned int GetDataTypeSize(DataType dataType) case DataType::Float16: return 2U; case DataType::Float32: case DataType::Signed32: return 4U; + case DataType::Signed64: return 8U; case DataType::QAsymmU8: return 1U; case DataType::QAsymmS8: return 1U; case DataType::QSymmS8: return 1U; @@ -171,6 +172,7 @@ constexpr const char* GetDataTypeName(DataType dataType) { case DataType::Float16: return "Float16"; case DataType::Float32: return "Float32"; + case DataType::Signed64: return "Signed64"; case DataType::QAsymmU8: return "QAsymmU8"; case DataType::QAsymmS8: return "QAsymmS8"; case DataType::QSymmS8: return "QSymmS8"; |