diff options
Diffstat (limited to 'src/backends/backendsCommon/LayerSupportRules.hpp')
-rw-r--r-- | src/backends/backendsCommon/LayerSupportRules.hpp | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp new file mode 100644 index 0000000000..db3f38ccbb --- /dev/null +++ b/src/backends/backendsCommon/LayerSupportRules.hpp @@ -0,0 +1,185 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <boost/assert.hpp> +#include <algorithm> + +namespace armnn +{ + +namespace +{ + +inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType) +{ + if (!weightsType) + { + return weightsType; + } + + switch(weightsType.value()) + { + case armnn::DataType::Float16: + case armnn::DataType::Float32: + return weightsType; + case armnn::DataType::QuantisedAsymm8: + return armnn::DataType::Signed32; + case armnn::DataType::QuantisedSymm16: + return armnn::DataType::Signed32; + default: + BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); + } + return armnn::EmptyOptional(); +} + +} //namespace + +template<typename F> +bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason) +{ + bool supported = rule(); + if (!supported && reason) + { + reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line + } + return supported; +} + +struct Rule +{ + bool operator()() const + { + return m_Res; + } + + bool m_Res = true; +}; + +template<typename T> +bool AllTypesAreEqualImpl(T t) +{ + return true; +} + +template<typename T, typename... Rest> +bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest) +{ + static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo"); + + return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...); +} + +struct TypesAreEqual : public Rule +{ + template<typename ... Ts> + TypesAreEqual(const Ts&... ts) + { + m_Res = AllTypesAreEqualImpl(ts...); + } +}; + +struct QuantizationParametersAreEqual : public Rule +{ + QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() && + info0.GetQuantizationOffset() == info1.GetQuantizationOffset(); + } +}; + +struct TypeAnyOf : public Rule +{ + template<typename Container> + TypeAnyOf(const TensorInfo& info, const Container& c) + { + m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) + { + return dt == info.GetDataType(); + }); + } +}; + +struct TypeIs : public Rule +{ + TypeIs(const TensorInfo& info, DataType dt) + { + m_Res = dt == info.GetDataType(); + } +}; + +struct BiasAndWeightsTypesMatch : public Rule +{ + BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights) + { + m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value(); + } +}; + +struct BiasAndWeightsTypesCompatible : public Rule +{ + template<typename Container> + BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c) + { + m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) + { + return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value(); + }); + } +}; + +struct ShapesAreSameRank : public Rule +{ + ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions(); + } +}; + +struct ShapesAreSameTotalSize : public Rule +{ + ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetNumElements() == info1.GetNumElements(); + } +}; + +struct ShapesAreBroadcastCompatible : public Rule +{ + unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx) + { + unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions(); + unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset]; + return sizeIn; + } + + ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out) + { + const TensorShape& shape0 = in0.GetShape(); + const TensorShape& shape1 = in1.GetShape(); + const TensorShape& outShape = out.GetShape(); + + for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++) + { + unsigned int sizeOut = outShape[i]; + unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i); + unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i); + + m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) && + ((sizeIn1 == sizeOut) || (sizeIn1 == 1)); + } + } +}; + +struct TensorNumDimensionsAreCorrect : public Rule +{ + TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions) + { + m_Res = info.GetNumDimensions() == expectedNumDimensions; + } +}; + +} //namespace armnn
\ No newline at end of file |