From 0ec008761ab26110dcb108d544be4040a14fd403 Mon Sep 17 00:00:00 2001 From: John Mcloughlin Date: Mon, 15 May 2023 17:03:49 +0100 Subject: IVGCVSW-7400 POW IVGCVSW-7278 SQUARED_DIFFERENCE. * Added 2 new operators as ElementWiseBinary ops * Ref End to End and unit tests * Serialize and Deserialize tests * Delegate and Opaque Delegate tests * TfLite Parser tests Signed-off-by: John Mcloughlin Change-Id: I537158127f602f0c41ca0402aa31655cd3bd4281 --- .../reference/workloads/ElementwiseFunction.cpp | 6 +++++ src/backends/reference/workloads/Power.hpp | 27 +++++++++++++++++++++ .../workloads/RefElementwiseBinaryWorkload.cpp | 14 +++++++++++ .../reference/workloads/SquaredDifference.hpp | 28 ++++++++++++++++++++++ 4 files changed, 75 insertions(+) create mode 100644 src/backends/reference/workloads/Power.hpp create mode 100644 src/backends/reference/workloads/SquaredDifference.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index c5b0ad1f24..4044f06ac4 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -14,6 +14,8 @@ #include "Rsqrt.hpp" #include "Sin.hpp" #include "Sqrt.hpp" +#include "Power.hpp" +#include "SquaredDifference.hpp" namespace armnn @@ -67,6 +69,8 @@ template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; @@ -74,6 +78,8 @@ template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; +template struct armnn::ElementwiseBinaryFunction>; // Comparison template struct armnn::ElementwiseBinaryFunction>; diff --git a/src/backends/reference/workloads/Power.hpp b/src/backends/reference/workloads/Power.hpp new file mode 100644 index 0000000000..744328e8bf --- /dev/null +++ b/src/backends/reference/workloads/Power.hpp @@ -0,0 +1,27 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + +template +struct power +{ + typedef T result_type; + typedef T first_argument_type; + + T + operator()(const T& input1, const T& input2) const + { + T power = armnn::numeric_cast(std::pow(static_cast(input1), static_cast(input2))); + return power; + } +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp index 5dc77f8496..e71cdd4e3c 100644 --- a/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseBinaryWorkload.cpp @@ -11,6 +11,8 @@ #include "RefWorkloadUtils.hpp" #include "Maximum.hpp" #include "Minimum.hpp" +#include "SquaredDifference.hpp" +#include "Power.hpp" #include @@ -44,6 +46,8 @@ void ExecuteFunction(std::vector inputs, using MinimumFunction = ElementwiseBinaryFunction>; using MulFunction = ElementwiseBinaryFunction>; using SubFunction = ElementwiseBinaryFunction>; + using SqDiffFunction = ElementwiseBinaryFunction>; + using PowerFunction = ElementwiseBinaryFunction>; switch (operation) { @@ -77,6 +81,16 @@ void ExecuteFunction(std::vector inputs, SubFunction(inShape0, inShape1, outShape, *input0, *input1, *output); break; } + case BinaryOperation::SqDiff: + { + SqDiffFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } + case BinaryOperation::Power: + { + PowerFunction(inShape0, inShape1, outShape, *input0, *input1, *output); + break; + } default: { throw InvalidArgumentException(std::string("Unsupported binary operation ") + diff --git a/src/backends/reference/workloads/SquaredDifference.hpp b/src/backends/reference/workloads/SquaredDifference.hpp new file mode 100644 index 0000000000..c15b379a4d --- /dev/null +++ b/src/backends/reference/workloads/SquaredDifference.hpp @@ -0,0 +1,28 @@ +// +// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + +template +struct squaredDifference +{ + typedef T result_type; + typedef T first_argument_type; + + T + operator()(const T& input1, const T& input2) const + { + float diff = std::minus<>{}(static_cast(input1),static_cast(input2)); + T squaredDiff = armnn::numeric_cast(std::pow(static_cast(diff), 2)); + return squaredDiff; + } +}; + +} //namespace armnn -- cgit v1.2.1