ArmNN
 24.02
TestUtils.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/INetwork.hpp>
9 #include <Graph.hpp>
10 #include <Runtime.hpp>
11 
13  unsigned int fromIndex = 0, unsigned int toIndex = 0);
14 
16 {
17 public:
18  LayerNameAndTypeCheck(armnn::LayerType layerType, const char* name)
19  : m_layerType(layerType)
20  , m_name(name)
21  {}
22 
23  bool operator()(const armnn::Layer* const layer)
24  {
25  return (layer->GetNameStr() == m_name &&
26  layer->GetType() == m_layerType);
27  }
28 private:
29  armnn::LayerType m_layerType;
30  const char* m_name;
31 };
32 
33 template <typename LayerT>
34 bool IsLayerOfType(const armnn::Layer* const layer)
35 {
36  return (layer->GetType() == armnn::LayerEnumOf<LayerT>());
37 }
38 
40 {
41  return (first == last);
42 }
43 
44 /// Checks each unary function in Us evaluates true for each correspondent layer in the sequence [first, last).
45 template <typename U, typename... Us>
46 bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last, U&& u, Us&&... us)
47 {
48  return u(*first) && CheckSequence(std::next(first), last, us...);
49 }
50 
51 template <typename LayerT>
52 bool CheckRelatedLayers(armnn::Graph& graph, const std::list<std::string>& testRelatedLayers)
53 {
54  for (auto& layer : graph)
55  {
56  if (layer->GetType() == armnn::LayerEnumOf<LayerT>())
57  {
58  auto& relatedLayers = layer->GetRelatedLayerNames();
59  if (!std::equal(relatedLayers.begin(), relatedLayers.end(), testRelatedLayers.begin(),
60  testRelatedLayers.end()))
61  {
62  return false;
63  }
64  }
65  }
66 
67  return true;
68 }
69 
namespace armnn
{
// Test-only accessors declared here. Per the Doxygen cross-references, their definitions live in
// TestUtils.cpp rather than in this header.

/// Returns a reference to the Graph associated with @p optNetPtr.
/// NOTE(review): presumably the graph held by the optimized network's implementation. Confirm the
/// ownership and lifetime in TestUtils.cpp before keeping the reference beyond the network's lifetime.
Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr);

/// Returns a reference to the ModelOptions (std::vector<BackendOptions>) associated with @p optNetPtr.
/// NOTE(review): presumably the options stored on the optimized network. Confirm in TestUtils.cpp.
ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr);

/// Returns a reference to the profiling service used by @p runtime.
/// NOTE(review): assumes @p runtime is non-null. Verify against the definition in TestUtils.cpp.
arm::pipe::IProfilingService& GetProfilingService(RuntimeImpl* runtime);

} // namespace armnn
CheckRelatedLayers
bool CheckRelatedLayers(armnn::Graph &graph, const std::list< std::string > &testRelatedLayers)
Definition: TestUtils.hpp:52
armnn::TensorInfo
Definition: Tensor.hpp:152
Graph.hpp
armnn::Layer
Definition: Layer.hpp:230
INetwork.hpp
armnn::GetProfilingService
arm::pipe::IProfilingService & GetProfilingService(armnn::RuntimeImpl *runtime)
Definition: TestUtils.cpp:59
LayerNameAndTypeCheck::operator()
bool operator()(const armnn::Layer *const layer)
Definition: TestUtils.hpp:23
Runtime.hpp
LayerNameAndTypeCheck
Definition: TestUtils.hpp:15
IsLayerOfType
bool IsLayerOfType(const armnn::Layer *const layer)
Definition: TestUtils.hpp:34
armnn::Layer::GetNameStr
const std::string & GetNameStr() const
Definition: Layer.hpp:240
armnn::Layer::GetType
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:286
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
Connect
void Connect(armnn::IConnectableLayer *from, armnn::IConnectableLayer *to, const armnn::TensorInfo &tensorInfo, unsigned int fromIndex=0, unsigned int toIndex=0)
Definition: TestUtils.cpp:14
armnn::GetGraphForTesting
Graph & GetGraphForTesting(IOptimizedNetwork *optNet)
Definition: TestUtils.cpp:49
armnn::IConnectableLayer
Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
Definition: INetwork.hpp:80
armnn::ModelOptions
std::vector< BackendOptions > ModelOptions
Definition: BackendOptions.hpp:18
armnn::GetModelOptionsForTesting
ModelOptions & GetModelOptionsForTesting(IOptimizedNetwork *optNet)
Definition: TestUtils.cpp:54
CheckSequence
bool CheckSequence(const armnn::Graph::ConstIterator first, const armnn::Graph::ConstIterator last)
Definition: TestUtils.hpp:39
LayerNameAndTypeCheck::LayerNameAndTypeCheck
LayerNameAndTypeCheck(armnn::LayerType layerType, const char *name)
Definition: TestUtils.hpp:18
armnn::LayerType
LayerType
When adding a new layer, also update the LastLayer enum value in the enum class LayerType below.
Definition: Types.hpp:491
armnn::Graph
Definition: Graph.hpp:30
armnn::TransformIterator
Definition: TransformIterator.hpp:25