aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp6
-rw-r--r--src/backends/reference/workloads/Power.hpp27
-rw-r--r--src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp14
-rw-r--r--src/backends/reference/workloads/SquaredDifference.hpp28
4 files changed, 75 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index c5b0ad1f24..4044f06ac4 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -14,6 +14,8 @@
#include "Rsqrt.hpp"
#include "Sin.hpp"
#include "Sqrt.hpp"
+#include "Power.hpp"
+#include "SquaredDifference.hpp"
namespace armnn
@@ -67,6 +69,8 @@ template struct armnn::ElementwiseBinaryFunction<std::multiplies<float>>;
template struct armnn::ElementwiseBinaryFunction<std::divides<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::power<float>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::squaredDifference<float>>;
template struct armnn::ElementwiseBinaryFunction<std::plus<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<std::minus<int32_t>>;
@@ -74,6 +78,8 @@ template struct armnn::ElementwiseBinaryFunction<std::multiplies<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<std::divides<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<armnn::maximum<int32_t>>;
template struct armnn::ElementwiseBinaryFunction<armnn::minimum<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::power<int32_t>>;
+template struct armnn::ElementwiseBinaryFunction<armnn::squaredDifference<int32_t>>;
// Comparison
template struct armnn::ElementwiseBinaryFunction<std::equal_to<float>>;
diff --git a/src/backends/reference/workloads/Power.hpp b/src/backends/reference/workloads/Power.hpp
new file mode 100644
index 0000000000..744328e8bf
--- /dev/null
+++ b/src/backends/reference/workloads/Power.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <iostream>
+
+namespace armnn
+{
+
+template<typename T>
+struct power
+{
+ typedef T result_type;
+ typedef T first_argument_type;
+
+ T
+ operator()(const T& input1, const T& input2) const
+ {
+ T power = armnn::numeric_cast<T>(std::pow(static_cast<float>(input1), static_cast<float>(input2)));
+ return power;
+ }
+};
+
+} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
index 5dc77f8496..e71cdd4e3c 100644
--- a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp
@@ -11,6 +11,8 @@
#include "RefWorkloadUtils.hpp"
#include "Maximum.hpp"
#include "Minimum.hpp"
+#include "SquaredDifference.hpp"
+#include "Power.hpp"
#include <Profiling.hpp>
@@ -44,6 +46,8 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs,
using MinimumFunction = ElementwiseBinaryFunction<armnn::minimum<DataType>>;
using MulFunction = ElementwiseBinaryFunction<std::multiplies<DataType>>;
using SubFunction = ElementwiseBinaryFunction<std::minus<DataType>>;
+ using SqDiffFunction = ElementwiseBinaryFunction<armnn::squaredDifference<DataType>>;
+ using PowerFunction = ElementwiseBinaryFunction<armnn::power<DataType>>;
switch (operation)
{
@@ -77,6 +81,16 @@ void ExecuteFunction(std::vector<ITensorHandle*> inputs,
SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
break;
}
+ case BinaryOperation::SqDiff:
+ {
+ SqDiffFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
+ case BinaryOperation::Power:
+ {
+ PowerFunction(inShape0, inShape1, outShape, *input0, *input1, *output);
+ break;
+ }
default:
{
throw InvalidArgumentException(std::string("Unsupported binary operation ") +
diff --git a/src/backends/reference/workloads/SquaredDifference.hpp b/src/backends/reference/workloads/SquaredDifference.hpp
new file mode 100644
index 0000000000..c15b379a4d
--- /dev/null
+++ b/src/backends/reference/workloads/SquaredDifference.hpp
@@ -0,0 +1,28 @@
+//
+// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <cmath>
+
+namespace armnn
+{
+
+template<typename T>
+struct squaredDifference
+{
+ typedef T result_type;
+ typedef T first_argument_type;
+
+ T
+ operator()(const T& input1, const T& input2) const
+ {
+ float diff = std::minus<>{}(static_cast<float>(input1),static_cast<float>(input2));
+ T squaredDiff = armnn::numeric_cast<T>(std::pow(static_cast<float>(diff), 2));
+ return squaredDiff;
+ }
+};
+
+} //namespace armnn