ArmNN
 21.11
PreluEndToEndTestImpl.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
14 
15 namespace
16 {
17 template<typename armnn::DataType DataType>
18 INetworkPtr CreatePreluNetwork(const armnn::TensorInfo& inputInfo,
19  const armnn::TensorInfo& alphaInfo,
20  const armnn::TensorInfo& outputInfo)
21 {
22  using namespace armnn;
23 
25 
26  IConnectableLayer* input = net->AddInputLayer(0, "input");
27  IConnectableLayer* alpha = net->AddInputLayer(1, "alpha");
28  IConnectableLayer* prelu = net->AddPreluLayer("Prelu");
29  IConnectableLayer* output = net->AddOutputLayer(0, "output");
30 
31  Connect(input, prelu, inputInfo, 0, 0);
32  Connect(alpha, prelu, alphaInfo, 0, 1);
33  Connect(prelu, output, outputInfo, 0, 0);
34 
35  return net;
36 }
37 
38 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
39 void PreluEndToEnd(const std::vector<BackendId>& backends,
40  const std::vector<T>& inputData,
41  const std::vector<T>& alphaData,
42  const std::vector<T>& expectedOutputData,
43  const float qScale ,
44  const int32_t qOffset)
45 {
46  using namespace armnn;
47 
48  armnn::TensorInfo inputInfo({ 2, 2, 2, 1 }, ArmnnType);
49  armnn::TensorInfo alphaInfo({ 1, 2, 2, 1 }, ArmnnType);
50  armnn::TensorInfo outputInfo({ 2, 2, 2, 1 }, ArmnnType);
51 
52  inputInfo.SetQuantizationOffset(qOffset);
53  inputInfo.SetQuantizationScale(qScale);
54  inputInfo.SetConstant(true);
55  alphaInfo.SetQuantizationOffset(qOffset);
56  alphaInfo.SetQuantizationScale(qScale);
57  alphaInfo.SetConstant(true);
58  outputInfo.SetQuantizationOffset(qOffset);
59  outputInfo.SetQuantizationScale(qScale);
60 
61  INetworkPtr net = CreatePreluNetwork<ArmnnType>(inputInfo, alphaInfo, outputInfo);
62 
63  CHECK(net);
64 
65  std::map<int, std::vector<T>> inputTensorData = { { 0, inputData }, { 1, alphaData} };
66  std::map<int, std::vector<T>> expectedOutputTensorData = { { 0, expectedOutputData } };
67 
68  EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net),
69  inputTensorData,
70  expectedOutputTensorData,
71  backends);
72 }
73 
74 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
75 void PreluEndToEndPositiveTest(const std::vector<BackendId>& backends, const float qScale = 1.0f,
76  const int32_t qOffset = 2)
77 {
78  std::vector<T> inputData{ 1, 2, 3, 4, 5, 6, 7, 8 };
79  std::vector<T> alphaData{ 2, 1, 1, 1 };
80 
81  std::vector<T> expectedOutputData{ 2, 2, 3, 4, 5, 6, 7, 8 };
82 
83  PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset);
84 }
85 
86 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
87 void PreluEndToEndNegativeTest(const std::vector<BackendId>& backends, const float qScale = 1.0f,
88  const int32_t qOffset = 0)
89 {
90  std::vector<T> inputData{ 1, -2, 3, 4, 5, 6, 7, 8 };
91  std::vector<T> alphaData{ 1, 2, 1, 1 };
92 
93  std::vector<T> expectedOutputData{ 1, -4, 3, 4, 5, 6, 7, 8 };
94 
95  PreluEndToEnd<ArmnnType>(backends, inputData, alphaData, expectedOutputData, qScale, qOffset);
96 }
97 
98 } // anonymous namespace
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:61
Copyright (c) 2021 ARM Limited and Contributors.
void SetQuantizationScale(float scale)
Definition: Tensor.cpp:475
void SetConstant(const bool IsConstant=true)
Marks the data corresponding to this tensor info as constant.
Definition: Tensor.cpp:516
void SetQuantizationOffset(int32_t offset)
Definition: Tensor.cpp:491
void Connect(armnn::IConnectableLayer *from, armnn::IConnectableLayer *to, const armnn::TensorInfo &tensorInfo, unsigned int fromIndex, unsigned int toIndex)
Definition: TestUtils.cpp:12
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:197
static INetworkPtr Create(NetworkOptions networkOptions={})
Definition: Network.cpp:478