diff options
author | Nikhil Raj <nikhil.raj@arm.com> | 2019-07-19 15:15:23 +0100 |
---|---|---|
committer | Nikhil Raj Arm <nikhil.raj@arm.com> | 2019-07-19 14:33:04 +0000 |
commit | 747f586b9ae0b206c45d9678becfa4b7c092aeb7 (patch) | |
tree | 81dbc8ba29ef0b65860749b0a7c1b0f41db1f2a1 /src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp | |
parent | 598950d611304ffeb9d57a5b28a13b0ddd629026 (diff) | |
download | armnn-747f586b9ae0b206c45d9678becfa4b7c092aeb7.tar.gz |
IVGCVSW-3479 Add End to End test for Prelu
Change-Id: I041bdf9e721a4384ea3c2be0184787dd1f4ea08e
Signed-off-by: Nikhil Raj <nikhil.raj@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp')
-rw-r--r-- | src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp new file mode 100644 index 0000000000..0dc1e78ced --- /dev/null +++ b/src/backends/backendsCommon/test/PreluEndToEndTestImpl.hpp @@ -0,0 +1,94 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include <ResolveType.hpp> + +#include <armnn/INetwork.hpp> + +#include <backendsCommon/test/CommonTestUtils.hpp> + +namespace +{ +template<typename armnn::DataType DataType> +INetworkPtr CreatePreluNetwork(const armnn::TensorInfo& inputInfo, + const armnn::TensorInfo& alphaInfo, + const armnn::TensorInfo& outputInfo) +{ + using namespace armnn; + + INetworkPtr net(INetwork::Create()); + + IConnectableLayer* input = net->AddInputLayer(0, "input"); + IConnectableLayer* alpha = net->AddInputLayer(1, "alpha"); + IConnectableLayer* prelu = net->AddPreluLayer("Prelu"); + IConnectableLayer* output = net->AddOutputLayer(0, "output"); + + Connect(input, prelu, inputInfo, 0, 0); + Connect(alpha, prelu, alphaInfo, 0, 1); + Connect(prelu, output, outputInfo, 0, 0); + + return net; +} + +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +void PreluEndToEnd(const std::vector<BackendId>& backends, + const std::vector<T>& inputData, + const std::vector<T>& alphaData, + const std::vector<T>& expectedOutputData, + const float qScale , + const int32_t qOffset) +{ + using namespace armnn; + + armnn::TensorInfo inputInfo({ 2, 2, 2, 1 }, ArmnnType); + armnn::TensorInfo alphaInfo({ 1, 2, 2, 1 }, ArmnnType); + armnn::TensorInfo outputInfo({ 2, 2, 2, 1 }, ArmnnType); + + inputInfo.SetQuantizationOffset(qOffset); + inputInfo.SetQuantizationScale(qScale); + alphaInfo.SetQuantizationOffset(qOffset); + alphaInfo.SetQuantizationScale(qScale); + outputInfo.SetQuantizationOffset(qOffset); + outputInfo.SetQuantizationScale(qScale); + + INetworkPtr net = CreatePreluNetwork<ArmnnType>(inputInfo, alphaInfo, outputInfo); + + BOOST_TEST_CHECKPOINT("Create a network"); + + std::map<int, std::vector<T>> inputTensorData = { { 0, inputData }, { 1, alphaData} }; + std::map<int, std::vector<T>> expectedOutputTensorData = { { 0, expectedOutputData } }; + + EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), + inputTensorData, + expectedOutputTensorData, + backends); +} + +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +void PreluEndToEndPositiveTest(const std::vector<BackendId>& backends, const float qScale = 1.0f, + const int32_t qOffset = 2) +{ + std::vector<T> inputData{ 1, 2, 3, 4, 5, 6, 7, 8 }; + std::vector<T> alphaData{ 2, 1, 1, 1 }; + + std::vector<T> expectedOutputData{ 2, 2, 3, 4, 5, 6, 7, 8 }; + + PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset); +} + +template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>> +void PreluEndToEndNegativeTest(const std::vector<BackendId>& backends, const float qScale = 1.0f, + const int32_t qOffset = 0) +{ + std::vector<T> inputData{ 1, -2, 3, 4, 5, 6, 7, 8 }; + std::vector<T> alphaData{ 1, 2, 1, 1 }; + + std::vector<T> expectedOutputData{ 1, -4, 3, 4, 5, 6, 7, 8 }; + + PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset); +} + +} // anonymous namespace
\ No newline at end of file |