aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/ArgMinMax.cpp12
-rw-r--r--src/backends/reference/workloads/ArgMinMax.hpp3
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/RefArgMinMaxWorkload.cpp13
4 files changed, 21 insertions, 9 deletions
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp
index c455c52e5a..3bf2853a20 100644
--- a/src/backends/reference/workloads/ArgMinMax.cpp
+++ b/src/backends/reference/workloads/ArgMinMax.cpp
@@ -12,7 +12,8 @@
namespace armnn
{
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
{
IgnoreUnused(outputTensorInfo);
@@ -39,9 +40,16 @@ void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorIn
tmpIndex = i;
}
}
- out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex);
+
+ out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
}
}
}
+template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
+template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
+ const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
+
} //namespace armnn
diff --git a/src/backends/reference/workloads/ArgMinMax.hpp b/src/backends/reference/workloads/ArgMinMax.hpp
index 5a9c6a8a2a..3958ed7afd 100644
--- a/src/backends/reference/workloads/ArgMinMax.hpp
+++ b/src/backends/reference/workloads/ArgMinMax.hpp
@@ -13,7 +13,8 @@
namespace armnn
{
-void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
+template <typename OUT>
+void ArgMinMax(Decoder<float>& in, OUT *out, const TensorInfo& inputTensorInfo,
const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
} //namespace armnn
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 937a32029e..cd9efc96af 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -5,8 +5,6 @@
list(APPEND armnnRefBackendWorkloads_sources
Abs.hpp
- ArgMinMax.cpp
- ArgMinMax.hpp
Activation.cpp
Activation.hpp
ArgMinMax.cpp
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
index 5f1eb73b61..b7246d5b93 100644
--- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
+++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
@@ -29,10 +29,15 @@ void RefArgMinMaxWorkload::Execute() const
const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- int32_t* output = GetOutputTensorData<int32_t>(0, m_Data);
-
- ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
- m_Data.m_Parameters.m_Axis);
+ if (m_Data.m_Parameters.m_Output_Type == armnn::DataType::Signed32) {
+ int32_t *output = GetOutputTensorData<int32_t>(0, m_Data);
+ ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+ m_Data.m_Parameters.m_Axis);
+ } else {
+ int64_t *output = GetOutputTensorData<int64_t>(0, m_Data);
+ ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function,
+ m_Data.m_Parameters.m_Axis);
+ }
}
} //namespace armnn \ No newline at end of file