diff options
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 18 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.hpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 19 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 4 | ||||
-rw-r--r-- | src/backends/cl/test/ClCreateWorkloadTests.cpp | 5 |
5 files changed, 50 insertions, 2 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index ec134a16e8..d79f6126a7 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -591,6 +591,24 @@ bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input, return true; } +bool ClLayerSupport::IsResizeSupported(const TensorInfo& input, + const TensorInfo& output, + const ResizeDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported) const +{ + ignore_unused(output); + + if (descriptor.m_Method == ResizeMethod::Bilinear) + { + return IsSupportedForDataTypeCl(reasonIfUnsupported, + input.GetDataType(), + &TrueFunc<>, + &FalseFuncU8<>); + } + + return false; +} + bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported) const diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 4d0f5bdfbb..1461f41691 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -198,6 +198,12 @@ public: const ReshapeDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsResizeSupported(const TensorInfo& input, + const TensorInfo& output, + const ResizeDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + + ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead") bool IsResizeBilinearSupported(const TensorInfo& input, const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 4bce653462..c662a9db29 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -251,6 +251,25 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMemCopy(const MemCopy return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info); } +std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + if (descriptor.m_Parameters.m_Method == ResizeMethod::Bilinear) + { + ResizeBilinearQueueDescriptor resizeBilinearDescriptor; + resizeBilinearDescriptor.m_Inputs = descriptor.m_Inputs; + resizeBilinearDescriptor.m_Outputs = descriptor.m_Outputs; + + resizeBilinearDescriptor.m_Parameters.m_DataLayout = descriptor.m_Parameters.m_DataLayout; + resizeBilinearDescriptor.m_Parameters.m_TargetWidth = descriptor.m_Parameters.m_TargetWidth; + resizeBilinearDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight; + + return MakeWorkload<ClResizeBilinearFloatWorkload, NullWorkload>(resizeBilinearDescriptor, info); + } + + return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info); +} + std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateResizeBilinear( const ResizeBilinearQueueDescriptor& descriptor, const WorkloadInfo& info) const diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index 8c3e756c0d..32925f7f95 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -91,6 +91,10 @@ public: std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + + ARMNN_DEPRECATED_MSG("Use CreateResize instead") std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index d401701dda..aa1393f407 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -726,8 +726,9 @@ static void ClResizeBilinearWorkloadTest(DataLayout dataLayout) auto workload = CreateResizeBilinearWorkloadTest<ResizeBilinearWorkloadType, DataType>(factory, graph, dataLayout); // Checks that inputs/outputs are as we expect them (see definition of CreateResizeBilinearWorkloadTest). - ResizeBilinearQueueDescriptor queueDescriptor = workload->GetData(); - auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); + auto queueDescriptor = workload->GetData(); + + auto inputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Inputs[0]); auto outputHandle = boost::polymorphic_downcast<IClTensorHandle*>(queueDescriptor.m_Outputs[0]); switch (dataLayout) |