ArmNN
 20.05
ActivationFixture.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TensorCopyUtils.hpp"
8 #include "WorkloadTestUtils.hpp"
9 
10 #include <test/TensorHelpers.hpp>
11 
12 #include <boost/numeric/conversion/cast.hpp>
13 #include <boost/multi_array.hpp>
14 
15 struct ActivationFixture
16 {
17  ActivationFixture()
18  {
19  auto boostArrayExtents = boost::extents
20  [boost::numeric_cast<boost::multi_array_types::extent_gen::index>(batchSize)]
21  [boost::numeric_cast<boost::multi_array_types::extent_gen::index>(channels)]
22  [boost::numeric_cast<boost::multi_array_types::extent_gen::index>(height)]
23  [boost::numeric_cast<boost::multi_array_types::extent_gen::index>(width)];
24  output.resize(boostArrayExtents);
25  outputExpected.resize(boostArrayExtents);
26  input.resize(boostArrayExtents);
27 
28  unsigned int inputShape[] = { batchSize, channels, height, width };
29  unsigned int outputShape[] = { batchSize, channels, height, width };
30 
31  inputTensorInfo = armnn::TensorInfo(4, inputShape, armnn::DataType::Float32);
32  outputTensorInfo = armnn::TensorInfo(4, outputShape, armnn::DataType::Float32);
33 
34  input = MakeRandomTensor<float, 4>(inputTensorInfo, 21453);
35  }
36 
37  unsigned int width = 17;
38  unsigned int height = 29;
39  unsigned int channels = 2;
40  unsigned int batchSize = 5;
41 
42  boost::multi_array<float, 4> output;
43  boost::multi_array<float, 4> outputExpected;
44  boost::multi_array<float, 4> input;
45 
46  armnn::TensorInfo inputTensorInfo;
47  armnn::TensorInfo outputTensorInfo;
48 
49  // Parameters used by some of the activation functions.
50  float a = 0.234f;
51  float b = -12.345f;
52 };
53 
54 
55 struct PositiveActivationFixture : public ActivationFixture
56 {
58  {
59  input = MakeRandomTensor<float, 4>(inputTensorInfo, 2342423, 0.0f, 1.0f);
60  }
61 };
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:33