diff options
Diffstat (limited to 'src/backends/reference/workloads/ArgMinMax.cpp')
-rw-r--r-- | src/backends/reference/workloads/ArgMinMax.cpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/ArgMinMax.cpp b/src/backends/reference/workloads/ArgMinMax.cpp index c455c52e5a..3bf2853a20 100644 --- a/src/backends/reference/workloads/ArgMinMax.cpp +++ b/src/backends/reference/workloads/ArgMinMax.cpp @@ -12,7 +12,8 @@ namespace armnn { -void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo, +template <typename OUT> +void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo, const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis) { IgnoreUnused(outputTensorInfo); @@ -39,9 +40,16 @@ void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorIn tmpIndex = i; } } - out[outer * innerElements + inner] = armnn::numeric_cast<int32_t>(tmpIndex); + + out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex); } } } +template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo, + const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis); + +template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo, + const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis); + } //namespace armnn |