aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/ArgMinMax.cpp
diff options
context:
space:
mode:
authorInki Dae <inki.dae@samsung.com>2020-09-10 15:33:54 +0900
committermike.kelly <mike.kelly@arm.com>2020-09-24 16:03:00 +0000
commitd4619e28a4cde423d5b4086a98c31f97b52a68d7 (patch)
tree9e6da174b200e8135ed2bbc2b0b8cb761d3e1a4f /src/backends/reference/workloads/ArgMinMax.cpp
parent02036e99c1b2074e5e5f04a2fe443f0c90689683 (diff)
downloadarmnn-d4619e28a4cde423d5b4086a98c31f97b52a68d7.tar.gz
Add int32 and int64 ArgMax op support
This patch adds int32 and int64 ArgMax op support. ArmNN already has an ArgMax op, but it is unused and does not support an int64 output type. This patch therefore adds a new data type, Signed64, and an ArgMinMax computation function for int64 support. By default, the output tensor type of the ArgMax op is int64 for TensorFlow Lite models, so this patch ensures the proper function — ArgMax for int64 or int32 — is called according to the parsed output_type value. With this patch, ArmNN supports both int64 and int32 output types for the ArgMinMax op. Changelog v1: - Check whether the output data type of the ArgMinMax op is valid. - Use a template function to support the int32 and int64 variants of the ArgMinMax function. - Keep Signed32 as the default data type of m_Output_Type. Change-Id: I7a8e7e38dd9e5acc81464571d8b4d51378fc7f14 Signed-off-by: Inki Dae <inki.dae@samsung.com>
Diffstat (limited to 'src/backends/reference/workloads/ArgMinMax.cpp')
-rw-r--r--src/backends/reference/workloads/ArgMinMax.cpp12
1 file changed, 10 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp
index c455c52e5a..3bf2853a20 100644
--- a/src/backends/reference/workloads/ArgMinMax.cpp
+++ b/src/backends/reference/workloads/ArgMinMax.cpp
@@ -12,7 +12,8 @@
namespace armnn
{
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
{
IgnoreUnused(outputTensorInfo);
@@ -39,9 +40,16 @@ void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorIn
tmpIndex = i;
}
}
- out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex);
+
+ out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
}
}
}
+template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
+template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
} //namespace armnn