diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 150 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 11 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.hpp | 3 |
3 files changed, 15 insertions, 149 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index d42404d25b..187cc01c77 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -14,6 +14,7 @@ #include <armnn/Descriptors.hpp> #include <backendsCommon/BackendRegistry.hpp> +#include <backendsCommon/LayerSupportRules.hpp> #include <backendsCommon/test/WorkloadTestUtils.hpp> #include <boost/core/ignore_unused.hpp> @@ -65,155 +66,6 @@ std::string CreateIncorrectDimensionsErrorMsg(unsigned int expected, } // anonymous namespace -namespace -{ -template<typename F> -bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason) -{ - bool supported = rule(); - if (!supported && reason) - { - reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line - } - return supported; -} - -struct Rule -{ - bool operator()() const - { - return m_Res; - } - - bool m_Res = true; -}; - -template<typename T> -bool AllTypesAreEqualImpl(T t) -{ - return true; -} - -template<typename T, typename... Rest> -bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest) -{ - static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo"); - - return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...); -} - -struct TypesAreEqual : public Rule -{ - template<typename ... Ts> - TypesAreEqual(const Ts&... ts) - { - m_Res = AllTypesAreEqualImpl(ts...); - } -}; - -struct QuantizationParametersAreEqual : public Rule -{ - QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1) - { - m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() && - info0.GetQuantizationOffset() == info1.GetQuantizationOffset(); - } -}; - -struct TypeAnyOf : public Rule -{ - template<typename Container> - TypeAnyOf(const TensorInfo& info, const Container& c) - { - m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) - { - return dt == info.GetDataType(); - }); - } -}; - -struct TypeIs : public Rule -{ - TypeIs(const TensorInfo& info, DataType dt) - { - m_Res = dt == info.GetDataType(); - } -}; - -struct BiasAndWeightsTypesMatch : public Rule -{ - BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights) - { - m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value(); - } -}; - -struct BiasAndWeightsTypesCompatible : public Rule -{ - template<typename Container> - BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c) - { - m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) - { - return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value(); - }); - } -}; - -struct ShapesAreSameRank : public Rule -{ - ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1) - { - m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions(); - } -}; - -struct ShapesAreSameTotalSize : public Rule -{ - ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1) - { - m_Res = info0.GetNumElements() == info1.GetNumElements(); - } -}; - -struct ShapesAreBroadcastCompatible : public Rule -{ - unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx) - { - unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions(); - unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset]; - return sizeIn; - } - - ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out) - { - const TensorShape& shape0 = in0.GetShape(); - const TensorShape& shape1 = in1.GetShape(); - const TensorShape& outShape = out.GetShape(); - - for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++) - { - unsigned int sizeOut = outShape[i]; - unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i); - unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i); - - m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) && - ((sizeIn1 == sizeOut) || (sizeIn1 == 1)); - } - } -}; - -struct TensorNumDimensionsAreCorrect : public Rule -{ - TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions) - { - m_Res = info.GetNumDimensions() == expectedNumDimensions; - } -}; - -} // namespace - - bool RefLayerSupport::IsActivationSupported(const TensorInfo& input, const TensorInfo& output, const ActivationDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 240acecbad..fff2fd2694 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -5,6 +5,7 @@ #include <Layer.hpp> #include <backendsCommon/CpuTensorHandle.hpp> #include <backendsCommon/MemCopyWorkload.hpp> +#include <backendsCommon/MemImportWorkload.hpp> #include <backendsCommon/MakeWorkloadHelper.hpp> #include "RefWorkloadFactory.hpp" #include "RefBackendId.hpp" @@ -250,6 +251,16 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCop return std::make_unique<CopyMemGenericWorkload>(descriptor, info); } +std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + if (descriptor.m_Inputs.empty()) + { + throw InvalidArgumentException("RefWorkloadFactory: CreateMemImport() expected an input tensor."); + } + return std::make_unique<ImportMemGenericWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index b012fbc6f6..314e11788e 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -110,6 +110,9 @@ public: std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; |