diff options
Diffstat (limited to 'src/backends/reference/workloads/RefArgMinMaxWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefArgMinMaxWorkload.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp index 5f1eb73b61..b7246d5b93 100644 --- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp +++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp @@ -29,10 +29,15 @@ void RefArgMinMaxWorkload::Execute() const const TensorInfo &outputTensorInfo = GetTensorInfo(m_Data.m_Outputs[0]); - int32_t* output = GetOutputTensorData<int32_t>(0, m_Data); - - ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, - m_Data.m_Parameters.m_Axis); + if (m_Data.m_Parameters.m_Output_Type == armnn::DataType::Signed32) { + int32_t *output = GetOutputTensorData<int32_t>(0, m_Data); + ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, + m_Data.m_Parameters.m_Axis); + } else { + int64_t *output = GetOutputTensorData<int64_t>(0, m_Data); + ArgMinMax(decoder, output, inputTensorInfo, outputTensorInfo, m_Data.m_Parameters.m_Function, + m_Data.m_Parameters.m_Axis); + } } } //namespace armnn
\ No newline at end of file |