diff options
Diffstat (limited to 'src/backends/reference/workloads')
4 files changed, 75 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index c5b0ad1f24..4044f06ac4 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -14,6 +14,8 @@ #include "Rsqrt.hpp" #include "Sin.hpp" #include "Sqrt.hpp" +#include "Power.hpp" +#include "SquaredDifference.hpp" namespace armnn @@ -67,6 +69,8 @@ template struct armnn::ElementwiseBinaryFunction<std::multiplies<float>>; template struct armnn::ElementwiseBinaryFunction<std::divides<float>>; template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>; template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>; +template struct armnn::ElementwiseBinaryFunction<armnn::power<float>>; +template struct armnn::ElementwiseBinaryFunction<armnn::squaredDifference<float>>; template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>; template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>; @@ -74,6 +78,8 @@ template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>; template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>; template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>; template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>; +template struct armnn::ElementwiseBinaryFunction<armnn::power<int32_t>>; +template struct armnn::ElementwiseBinaryFunction<armnn::squaredDifference<int32_t>>; // Comparison template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>; diff --git a/src/backends/reference/workloads/Power.hpp b/src/backends/reference/workloads/Power.hpp new file mode 100644 index 0000000000..744328e8bf --- /dev/null +++ b/src/backends/reference/workloads/Power.hpp @@ -0,0 +1,27 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <iostream> + +namespace armnn +{ + +template<typename T> +struct power +{ + typedef T result_type; + typedef T first_argument_type; + + T + operator()(const T& input1, const T& input2) const + { + T power = armnn::numeric_cast<T>(std::pow(static_cast<float>(input1), static_cast<float>(input2))); + return power; + } +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp index 5dc77f8496..e71cdd4e3c 100644 --- a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp @@ -11,6 +11,8 @@ #include "RefWorkloadUtils.hpp" #include "Maximum.hpp" #include "Minimum.hpp" +#include "SquaredDifference.hpp" +#include "Power.hpp" #include <Profiling.hpp> @@ -44,6 +46,8 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs, using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>; using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>; using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>; + using SqDiffFunction = ElementwiseBinaryFunction<armnn::squaredDifference<DataType>>; + using PowerFunction = ElementwiseBinaryFunction<armnn::power<DataType>>; switch (operation) { @@ -77,6 +81,16 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs, SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output); break; } + case BinaryOperation::SqDiff: + { + SqDiffFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Power: + { + PowerFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } default: { throw InvalidArgumentException(std::string("Unsupported binary operation ") + diff --git a/src/backends/reference/workloads/SquaredDifference.hpp b/src/backends/reference/workloads/SquaredDifference.hpp new file mode 100644 index 0000000000..c15b379a4d --- /dev/null +++ b/src/backends/reference/workloads/SquaredDifference.hpp @@ -0,0 +1,28 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <cmath> + +namespace armnn +{ + +template<typename T> +struct squaredDifference +{ + typedef T result_type; + typedef T first_argument_type; + + T + operator()(const T& input1, const T& input2) const + { + float diff = std::minus<>{}(static_cast<float>(input1),static_cast<float>(input2)); + T squaredDiff = armnn::numeric_cast<T>(std::pow(static_cast<float>(diff), 2)); + return squaredDiff; + } +}; + +} //namespace armnn |