diff options
Diffstat (limited to 'src/backends/backendsCommon/test')
-rw-r--r-- | src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp index 3b2c47fb94..f9de3b928f 100644 --- a/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp +++ b/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp @@ -87,7 +87,8 @@ namespace } template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> - void BroadcastToEndToEndElementWiseBinary(const std::vector<BackendId>& backends) + void BroadcastToEndToEndElementWiseBinary(const std::vector<BackendId>& backends, + const ElementwiseBinaryDescriptor& elementWiseDescriptor) { float qScale = 1.0f; int32_t qOffset = 0; @@ -114,17 +115,39 @@ namespace 1, 1, 1, 1 }, qScale, qOffset); - std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>({ - 65, 144, 91, 161, - 65, 144, 91, 161, - 65, 144, 91, 161, - 65, 144, 91, 161 - }, qScale, qOffset); + std::vector<T> expectedOutputData; + if (elementWiseDescriptor.m_Operation == BinaryOperation::Mul || + elementWiseDescriptor.m_Operation == BinaryOperation::Div) { + expectedOutputData = armnnUtils::QuantizedVector<T>({ + 65, 144, 91, 161, + 65, 144, 91, 161, + 65, 144, 91, 161, + 65, 144, 91, 161 + }, qScale, qOffset); + } + else if (elementWiseDescriptor.m_Operation == BinaryOperation::Add) + { + expectedOutputData = armnnUtils::QuantizedVector<T>({ + 66, 145, 92, 162, + 66, 145, 92, 162, + 66, 145, 92, 162, + 66, 145, 92, 162 + }, qScale, qOffset); + } + else if (elementWiseDescriptor.m_Operation == BinaryOperation::Sub) + { + expectedOutputData = armnnUtils::QuantizedVector<T>({ + 64, 143, 90, 160, + 64, 143, 90, 160, + 64, 143, 90, 160, + 64, 143, 90, 160 + }, qScale, qOffset); + } auto descriptor = armnn::BroadcastToDescriptor(armnn::TensorShape({ 4, 4 })); CHECK(descriptor.m_BroadcastToShape == outputTensorShape); INetworkPtr network = CreateBroadcastToNetworkWithElementWiseBinary(descriptor, - BinaryOperation::Mul, + elementWiseDescriptor, inputInfo, inputInfoElementWise, outputInfo); |