aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test')
-rw-r--r--src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp39
1 files changed, 31 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp
index 3b2c47fb94..f9de3b928f 100644
--- a/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp
+++ b/src/backends/backendsCommon/test/BroadcastToEndToEndTestImpl.hpp
@@ -87,7 +87,8 @@ namespace
}
template <armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
- void BroadcastToEndToEndElementWiseBinary(const std::vector<BackendId>& backends)
+ void BroadcastToEndToEndElementWiseBinary(const std::vector<BackendId>& backends,
+ const ElementwiseBinaryDescriptor& elementWiseDescriptor)
{
float qScale = 1.0f;
int32_t qOffset = 0;
@@ -114,17 +115,39 @@ namespace
1, 1, 1, 1
}, qScale, qOffset);
- std::vector<T> expectedOutputData = armnnUtils::QuantizedVector<T>({
- 65, 144, 91, 161,
- 65, 144, 91, 161,
- 65, 144, 91, 161,
- 65, 144, 91, 161
- }, qScale, qOffset);
+ std::vector<T> expectedOutputData;
+ if (elementWiseDescriptor.m_Operation == BinaryOperation::Mul ||
+ elementWiseDescriptor.m_Operation == BinaryOperation::Div) {
+ expectedOutputData = armnnUtils::QuantizedVector<T>({
+ 65, 144, 91, 161,
+ 65, 144, 91, 161,
+ 65, 144, 91, 161,
+ 65, 144, 91, 161
+ }, qScale, qOffset);
+ }
+ else if (elementWiseDescriptor.m_Operation == BinaryOperation::Add)
+ {
+ expectedOutputData = armnnUtils::QuantizedVector<T>({
+ 66, 145, 92, 162,
+ 66, 145, 92, 162,
+ 66, 145, 92, 162,
+ 66, 145, 92, 162
+ }, qScale, qOffset);
+ }
+ else if (elementWiseDescriptor.m_Operation == BinaryOperation::Sub)
+ {
+ expectedOutputData = armnnUtils::QuantizedVector<T>({
+ 64, 143, 90, 160,
+ 64, 143, 90, 160,
+ 64, 143, 90, 160,
+ 64, 143, 90, 160
+ }, qScale, qOffset);
+ }
auto descriptor = armnn::BroadcastToDescriptor(armnn::TensorShape({ 4, 4 }));
CHECK(descriptor.m_BroadcastToShape == outputTensorShape);
INetworkPtr network = CreateBroadcastToNetworkWithElementWiseBinary(descriptor,
- BinaryOperation::Mul,
+ elementWiseDescriptor,
inputInfo,
inputInfoElementWise,
outputInfo);