aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp5
-rw-r--r--include/armnn/Types.hpp1
-rw-r--r--include/armnn/TypesUtils.hpp2
3 files changed, 7 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 241b23d4ed..2834336fb2 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -53,17 +53,20 @@ struct ArgMinMaxDescriptor
ArgMinMaxDescriptor()
: m_Function(ArgMinMaxFunction::Min)
, m_Axis(-1)
+ , m_Output_Type(armnn::DataType::Signed32)
{}
bool operator ==(const ArgMinMaxDescriptor &rhs) const
{
- return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis;
+ return m_Function == rhs.m_Function && m_Axis == rhs.m_Axis && m_Output_Type == rhs.m_Output_Type;
}
/// Specify if the function is to find Min or Max.
ArgMinMaxFunction m_Function;
/// Axis to reduce across the input tensor.
int m_Axis;
+ // Tensor data type and this could be int32 or int64. Default type is int64.
+ armnn::DataType m_Output_Type;
};
/// A ComparisonDescriptor for the ComparisonLayer
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 11d807cd89..4a01549a14 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -41,6 +41,7 @@ enum class DataType
QSymmS8 = 7,
QAsymmS8 = 8,
BFloat16 = 9,
+ Signed64 = 10,
QuantisedAsymm8 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QAsymmU8 instead.") = QAsymmU8,
QuantisedSymm16 ARMNN_DEPRECATED_ENUM_MSG("Use DataType::QSymmS16 instead.") = QSymmS16
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index a2b3c95752..efc69deb67 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -120,6 +120,7 @@ constexpr unsigned int GetDataTypeSize(DataType dataType)
case DataType::Float16: return 2U;
case DataType::Float32:
case DataType::Signed32: return 4U;
+ case DataType::Signed64: return 8U;
case DataType::QAsymmU8: return 1U;
case DataType::QAsymmS8: return 1U;
case DataType::QSymmS8: return 1U;
@@ -171,6 +172,7 @@ constexpr const char* GetDataTypeName(DataType dataType)
{
case DataType::Float16: return "Float16";
case DataType::Float32: return "Float32";
+ case DataType::Signed64: return "Signed64";
case DataType::QAsymmU8: return "QAsymmU8";
case DataType::QAsymmS8: return "QAsymmS8";
case DataType::QSymmS8: return "QSymmS8";