aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/LayerSupportRules.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/LayerSupportRules.hpp')
-rw-r--r--src/backends/backendsCommon/LayerSupportRules.hpp185
1 files changed, 185 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
new file mode 100644
index 0000000000..db3f38ccbb
--- /dev/null
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -0,0 +1,185 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <boost/assert.hpp>
+#include <algorithm>
+
+namespace armnn
+{
+
+namespace
+{
+
+/// Maps a weights data type to the bias data type that must accompany it.
+/// Float weights keep their own type for the bias; quantised weights pair
+/// with Signed32 biases. An empty Optional is passed through unchanged.
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+    if (!weightsType)
+    {
+        return weightsType;
+    }
+
+    switch(weightsType.value())
+    {
+        case armnn::DataType::Float16:
+        case armnn::DataType::Float32:
+            return weightsType;
+        case armnn::DataType::QuantisedAsymm8:
+        case armnn::DataType::QuantisedSymm16:
+            // Both quantised weight types require 32-bit signed integer biases.
+            return armnn::DataType::Signed32;
+        default:
+            BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+    }
+    return armnn::EmptyOptional();
+}
+
+} //namespace
+
+/// Evaluates a support rule and, on failure, appends the supplied reason to
+/// the caller's diagnostic string.
+/// @param rule                Callable returning true when the configuration is supported.
+/// @param reasonIfUnsupported Optional reference to a string collecting failure reasons.
+/// @param reason              Human-readable explanation recorded on failure; may be nullptr.
+/// @return The rule's verdict.
+template<typename F>
+bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
+{
+    const bool supported = rule();
+    // Guard on has_value(): calling value() on an empty Optional is invalid,
+    // and callers may legitimately pass EmptyOptional when they do not want
+    // diagnostics collected.
+    if (!supported && reason != nullptr && reasonIfUnsupported.has_value())
+    {
+        reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
+    }
+    return supported;
+}
+
+/// Base class for all layer support rules. A concrete rule computes its
+/// verdict in its constructor and stores it in m_Res; invoking the rule via
+/// operator() simply reports that stored verdict.
+struct Rule
+{
+    bool operator()() const { return m_Res; }
+
+    bool m_Res = true; // Defaults to "supported" until a rule decides otherwise.
+};
+
+/// Recursion terminator: a single remaining tensor is trivially "all equal".
+/// The parameter is unnamed because only its presence matters.
+template<typename T>
+bool AllTypesAreEqualImpl(const T&)
+{
+    return true;
+}
+
+/// Checks that every TensorInfo in the pack shares the same DataType by
+/// comparing neighbours pairwise and recursing on the tail.
+/// Arguments are taken by const reference to avoid copying each TensorInfo
+/// at every level of the recursion.
+template<typename T, typename... Rest>
+bool AllTypesAreEqualImpl(const T& t1, const T& t2, const Rest&... rest)
+{
+    static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
+
+    return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
+}
+
+/// Rule: every supplied tensor must have the same data type.
+struct TypesAreEqual : public Rule
+{
+    template<typename ... Ts>
+    TypesAreEqual(const Ts&... ts)
+    {
+        // Delegates the pairwise comparison to the variadic helper above.
+        m_Res = AllTypesAreEqualImpl(ts...);
+    }
+};
+
+/// Rule: two tensors must agree on both quantization scale and offset.
+/// Scale comparison is an exact float equality, which is intentional here:
+/// the quantization parameters must be identical, not merely close.
+struct QuantizationParametersAreEqual : public Rule
+{
+    QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        const bool scalesMatch  = info0.GetQuantizationScale()  == info1.GetQuantizationScale();
+        const bool offsetsMatch = info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
+        m_Res = scalesMatch && offsetsMatch;
+    }
+};
+
+/// Rule: the tensor's data type must appear in the given container of
+/// permitted data types.
+struct TypeAnyOf : public Rule
+{
+    template<typename Container>
+    TypeAnyOf(const TensorInfo& info, const Container& c)
+    {
+        const DataType actual = info.GetDataType();
+        m_Res = std::any_of(c.begin(), c.end(), [actual](DataType allowed)
+        {
+            return allowed == actual;
+        });
+    }
+};
+
+/// Rule: the tensor must have exactly the given data type.
+struct TypeIs : public Rule
+{
+    TypeIs(const TensorInfo& info, DataType dt)
+    {
+        m_Res = (info.GetDataType() == dt);
+    }
+};
+
+/// Rule: the bias tensor's data type must be the one implied by the weights
+/// tensor's data type (e.g. Signed32 biases for quantised weights).
+struct BiasAndWeightsTypesMatch : public Rule
+{
+    BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
+    {
+        const auto expectedBiasType = GetBiasTypeFromWeightsType(weights.GetDataType());
+        m_Res = (biases.GetDataType() == expectedBiasType.value());
+    }
+};
+
+/// Rule: the bias type implied by the weights tensor's data type must be one
+/// of the data types listed in the container.
+/// @param info The weights TensorInfo whose data type implies the bias type.
+/// @param c    Container of DataType values accepted for the bias.
+struct BiasAndWeightsTypesCompatible : public Rule
+{
+    template<typename Container>
+    BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
+    {
+        // The implied bias type depends only on the weights tensor, so compute
+        // it once up front instead of once per candidate type in the predicate.
+        const DataType expectedBiasType = GetBiasTypeFromWeightsType(info.GetDataType()).value();
+        m_Res = std::any_of(c.begin(), c.end(), [expectedBiasType](DataType dt)
+        {
+            return dt == expectedBiasType;
+        });
+    }
+};
+
+/// Rule: both tensors must have the same number of dimensions.
+struct ShapesAreSameRank : public Rule
+{
+    ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        const unsigned int rank0 = info0.GetShape().GetNumDimensions();
+        const unsigned int rank1 = info1.GetShape().GetNumDimensions();
+        m_Res = (rank0 == rank1);
+    }
+};
+
+/// Rule: both tensors must contain the same total number of elements
+/// (their shapes may differ as long as the element counts match).
+struct ShapesAreSameTotalSize : public Rule
+{
+    ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
+    {
+        const unsigned int count0 = info0.GetNumElements();
+        const unsigned int count1 = info1.GetNumElements();
+        m_Res = (count0 == count1);
+    }
+};
+
+/// Rule: the two input shapes must be broadcastable to the output shape,
+/// i.e. along every output dimension each input's size is either equal to
+/// the output size or 1 (missing leading dimensions count as 1).
+struct ShapesAreBroadcastCompatible : public Rule
+{
+    /// Returns the effective size of input dimension @p idx once the input
+    /// shape is right-aligned against the output shape; dimensions the input
+    /// lacks are treated as size 1. Static because it uses no object state.
+    /// NOTE(review): assumes the input rank never exceeds the output rank,
+    /// otherwise the unsigned subtraction below wraps — confirm with callers.
+    static unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
+    {
+        unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
+        unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
+        return sizeIn;
+    }
+
+    ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
+    {
+        const TensorShape& shape0 = in0.GetShape();
+        const TensorShape& shape1 = in1.GetShape();
+        const TensorShape& outShape = out.GetShape();
+
+        // Stop iterating as soon as any dimension has proven incompatible.
+        for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
+        {
+            unsigned int sizeOut = outShape[i];
+            unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
+            unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
+
+            m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
+                     ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
+        }
+    }
+};
+
+/// Rule: the tensor must have exactly the expected number of dimensions.
+struct TensorNumDimensionsAreCorrect : public Rule
+{
+    TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
+    {
+        m_Res = (expectedNumDimensions == info.GetNumDimensions());
+    }
+};
+
+} //namespace armnn \ No newline at end of file