diff options
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r-- | src/backends/reference/workloads/ArgMinMax.cpp | 12 | ||||
-rw-r--r-- | src/backends/reference/workloads/ArgMinMax.hpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefArgMinMaxWorkload.cpp | 13 |
4 files changed, 21 insertions, 9 deletions
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp index c455c52e5a..3bf2853a20 100644 --- a/src/backends/reference/workloads/ArgMinMax.cpp +++ b/src/backends/reference/workloads/ArgMinMax.cpp @@ -12,7 +12,8 @@ namespace armnn { -void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo, +template <typename OUT> +void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo, const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis) { IgnoreUnused(outputTensorInfo); @@ -39,9 +40,16 @@ void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorIn tmpIndex = i; } } - out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex); + + out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex); } } } +template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo, + const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis); + +template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo, + const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis); + } //namespace armnn diff --git a/src/backends/reference/workloads/ArgMinMax.hpp b/src/backends/reference/workloads/ArgMinMax.hpp index 5a9c6a8a2a..3958ed7afd 100644 --- a/src/backends/reference/workloads/ArgMinMax.hpp +++ b/src/backends/reference/workloads/ArgMinMax.hpp @@ -13,7 +13,8 @@ namespace armnn { -void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo, +template <typename OUT> +void ArgMinMax(Decoder<float>& in, OUT *out, const TensorInfo& inputTensorInfo, const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis); } //namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 937a32029e..cd9efc96af 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -5,8 +5,6 @@ list(APPEND armnnRefBackendWorkloads_sources Abs.hpp - ArgMinMax.cpp - ArgMinMax.hpp Activation.cpp Activation.hpp ArgMinMax.cpp diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp index 5f1eb73b61..b7246d5b93 100644 --- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp +++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp @@ -29,10 +29,15 @@ void RefArgMinMaxWorkload::Execute() const const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]); - int32_t* output = GetOutputTensorData<int32_t>(0, m_Data); - - ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, - m_Data.m_Parameters.m_Axis); + if (m_Data.m_Parameters.m_Output_Type == armnn::DataType::Signed32) { + int32_t *output = GetOutputTensorData<int32_t>(0, m_Data); + ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, + m_Data.m_Parameters.m_Axis); + } else { + int64_t *output = GetOutputTensorData<int64_t>(0, m_Data); + ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, + m_Data.m_Parameters.m_Axis); + } } } //namespace armnn
\ No newline at end of file |