aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp203
1 files changed, 191 insertions, 12 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 4b32a8938d..cdc6acae7f 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -9,11 +9,16 @@
#include <InternalTypes.hpp>
#include <LayerSupportCommon.hpp>
#include <armnn/Types.hpp>
+#include <armnn/Descriptors.hpp>
#include <backendsCommon/BackendRegistry.hpp>
#include <boost/core/ignore_unused.hpp>
+#include <vector>
+#include <algorithm>
+#include <array>
+
using namespace boost;
namespace armnn
@@ -41,17 +46,171 @@ bool IsSupportedForDataTypeRef(Optional<std::string&> reasonIfUnsupported,
} // anonymous namespace
+
+namespace
+{
+template<typename F>
+bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
+{
+ bool supported = rule();
+ if (!supported && reason)
+ {
+ reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
+ }
+ return supported;
+}
+
+struct Rule
+{
+ bool operator()() const
+ {
+ return m_Res;
+ }
+
+ bool m_Res = true;
+};
+
+template<class none = void>
+bool AllTypesAreEqualImpl()
+{
+ return true;
+}
+
+template<typename T, typename... Rest>
+bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
+{
+ static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
+
+ return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(rest...);
+}
+
+struct TypesAreEqual : public Rule
+{
+ template<typename ... Ts>
+ TypesAreEqual(const Ts&... ts)
+ {
+ m_Res = AllTypesAreEqualImpl(ts...);
+ }
+};
+
+struct QuantizationParametersAreEqual : public Rule
+{
+ QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
+ {
+ m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
+ info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
+ }
+};
+
+struct TypeAnyOf : public Rule
+{
+ template<typename Container>
+ TypeAnyOf(const TensorInfo& info, const Container& c)
+ {
+ m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
+ {
+ return dt == info.GetDataType();
+ });
+ }
+};
+
+struct ShapesAreSameRank : public Rule
+{
+ ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
+ {
+ m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
+ }
+};
+
+struct ShapesAreBroadcastCompatible : public Rule
+{
+ unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
+ {
+ unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
+ unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
+ return sizeIn;
+ }
+
+ ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
+ {
+ const TensorShape& shape0 = in0.GetShape();
+ const TensorShape& shape1 = in1.GetShape();
+ const TensorShape& outShape = out.GetShape();
+
+ for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
+ {
+ unsigned int sizeOut = outShape[i];
+ unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
+ unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
+
+ m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
+ ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
+ }
+ }
+};
+} // namespace
+
+
bool RefLayerSupport::IsActivationSupported(const TensorInfo& input,
const TensorInfo& output,
const ActivationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(output);
- ignore_unused(descriptor);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ // Define supported types.
+ std::array<DataType,2> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+ "Reference activation: input type not supported.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference activation: output type not supported.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference activation: input and output types mismatched.");
+
+ supported &= CheckSupportRule(ShapesAreSameRank(input, output), reasonIfUnsupported,
+ "Reference activation: input and output shapes are of different rank.");
+
+
+ struct ActivationFunctionSupported : public Rule
+ {
+ ActivationFunctionSupported(const ActivationDescriptor& desc)
+ {
+ switch(desc.m_Function)
+ {
+ case ActivationFunction::Abs:
+ case ActivationFunction::BoundedReLu:
+ case ActivationFunction::LeakyReLu:
+ case ActivationFunction::Linear:
+ case ActivationFunction::ReLu:
+ case ActivationFunction::Sigmoid:
+ case ActivationFunction::SoftReLu:
+ case ActivationFunction::Sqrt:
+ case ActivationFunction::Square:
+ case ActivationFunction::TanH:
+ {
+ m_Res = true;
+ break;
+ }
+ default:
+ {
+ m_Res = false;
+ break;
+ }
+ }
+ }
+ };
+
+ // Function is supported
+ supported &= CheckSupportRule(ActivationFunctionSupported(descriptor), reasonIfUnsupported,
+ "Reference activation: function not supported.");
+
+ return supported;
}
bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
@@ -59,12 +218,32 @@ bool RefLayerSupport::IsAdditionSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input1);
- ignore_unused(output);
- return IsSupportedForDataTypeRef(reasonIfUnsupported,
- input0.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ bool supported = true;
+
+ std::array<DataType,2> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
+ "Reference addition: input 0 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
+ "Reference addition: input 1 is not a supported type.");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference addition: output is not a supported type.");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference addition: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, output), reasonIfUnsupported,
+ "Reference addition: input and output types are mismatched");
+
+ supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
+ "Reference addition: shapes are not suitable for implicit broadcast.");
+
+ return supported;
}
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,