ArmNN
 20.08
LayerSupportRules.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
9 #include <algorithm>
10 
11 namespace armnn
12 {
13 
15 {
16  if (!weightsType)
17  {
18  return weightsType;
19  }
20 
21  switch(weightsType.value())
22  {
25  return weightsType;
32  default:
33  ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
34  }
35  return armnn::EmptyOptional();
36 }
37 
38 template<typename F>
39 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
40 {
41  bool supported = rule();
42  if (!supported && reason)
43  {
44  reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
45  }
46  return supported;
47 }
48 
/// Base class for layer-support rules. Each derived rule computes its verdict
/// in its constructor and stores it in m_Res; invoking the object afterwards
/// simply reports that precomputed result.
struct Rule
{
    /// Reports the verdict computed by the derived constructor.
    bool operator()() const { return m_Res; }

    bool m_Res = true; // Defaults to "supported"; derived rules overwrite it.
};
58 
/// Recursion terminator for AllTypesAreEqualImpl: a single remaining value is
/// trivially "equal" - all pairwise comparisons have already been made.
template<typename T>
bool AllTypesAreEqualImpl(T)
{
    return true;
}
64 
65 template<typename T, typename... Rest>
66 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
67 {
68  static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
69 
70  return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
71 }
72 
73 struct TypesAreEqual : public Rule
74 {
75  template<typename ... Ts>
76  TypesAreEqual(const Ts&... ts)
77  {
78  m_Res = AllTypesAreEqualImpl(ts...);
79  }
80 };
81 
83 {
85  {
86  m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
88  }
89 };
90 
91 struct TypeAnyOf : public Rule
92 {
93  template<typename Container>
94  TypeAnyOf(const TensorInfo& info, const Container& c)
95  {
96  m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
97  {
98  return dt == info.GetDataType();
99  });
100  }
101 };
102 
103 struct TypeIs : public Rule
104 {
106  {
107  m_Res = dt == info.GetDataType();
108  }
109 };
110 
112 {
114  {
115  m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
116  }
117 };
118 
120 {
121  BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
122  {
123  m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
124  }
125 };
126 
128 {
129  template<typename Container>
130  BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
131  {
132  m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
133  {
134  return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
135  });
136  }
137 };
138 
139 struct ShapesAreSameRank : public Rule
140 {
141  ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
142  {
143  m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
144  }
145 };
146 
148 {
149  ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
150  {
151  m_Res = info0.GetNumElements() == info1.GetNumElements();
152  }
153 };
154 
156 {
157  unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
158  {
159  unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
160  unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
161  return sizeIn;
162  }
163 
164  ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
165  {
166  const TensorShape& shape0 = in0.GetShape();
167  const TensorShape& shape1 = in1.GetShape();
168  const TensorShape& outShape = out.GetShape();
169 
170  for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
171  {
172  unsigned int sizeOut = outShape[i];
173  unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
174  unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
175 
176  m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
177  ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
178  }
179  }
180 };
181 
183 {
184  TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
185  {
186  m_Res = info.GetNumDimensions() == expectedNumDimensions;
187  }
188 };
189 
190 } //namespace armnn
TypeNotPerAxisQuantized(const TensorInfo &info)
TypeAnyOf(const TensorInfo &info, const Container &c)
const TensorShape & GetShape() const
Definition: Tensor.hpp:187
TypeIs(const TensorInfo &info, DataType dt)
bool HasPerAxisQuantization() const
Definition: Tensor.cpp:438
ISubgraphViewConverter supported
QuantizationParametersAreEqual(const TensorInfo &info0, const TensorInfo &info1)
TypesAreEqual(const Ts &... ts)
bool operator()() const
Copyright (c) 2020 ARM Limited.
ShapesAreSameTotalSize(const TensorInfo &info0, const TensorInfo &info1)
armnn::Optional< armnn::DataType > GetBiasTypeFromWeightsType(armnn::Optional< armnn::DataType > weightsType)
DataType
Definition: Types.hpp:32
#define ARMNN_ASSERT_MSG(COND, MSG)
Definition: Assert.hpp:15
ShapesAreBroadcastCompatible(const TensorInfo &in0, const TensorInfo &in1, const TensorInfo &out)
int32_t GetQuantizationOffset() const
Definition: Tensor.cpp:470
BiasAndWeightsTypesCompatible(const TensorInfo &info, const Container &c)
float GetQuantizationScale() const
Definition: Tensor.cpp:453
DataType GetDataType() const
Definition: Tensor.hpp:194
BiasAndWeightsTypesMatch(const TensorInfo &biases, const TensorInfo &weights)
ShapesAreSameRank(const TensorInfo &info0, const TensorInfo &info1)
bool AllTypesAreEqualImpl(T)
unsigned int CalcInputSize(const TensorShape &in, const TensorShape &out, unsigned int idx)
EmptyOptional is used to initialize the Optional class in case we want to have default value for an O...
Definition: Optional.hpp:32
unsigned int GetNumDimensions() const
Function that returns the tensor rank.
Definition: Tensor.cpp:175
TensorNumDimensionsAreCorrect(const TensorInfo &info, unsigned int expectedNumDimensions)
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:191
bool IsQuantized() const
Definition: Tensor.cpp:496
bool CheckSupportRule(F rule, Optional< std::string &> reasonIfUnsupported, const char *reason)
unsigned int GetNumElements() const
Definition: Tensor.hpp:192