aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-12-12 17:28:05 +0000
committerJim Flynn Arm <jim.flynn@arm.com>2019-12-17 19:54:02 +0000
commit93e023b9917f695e4e18f8c6ae8c4e1c84ba3b37 (patch)
treec149b43b719ae3afd7e80a931c48ec1f92f2778a
parent664e4f2dd6a5f7053f41af1ee2d04a4e490f7b4c (diff)
downloadarmnn-93e023b9917f695e4e18f8c6ae8c4e1c84ba3b37.tar.gz
IVGCVSW-4262 Use ACL Permute and Reshape Validate function in Neon and CL
!android-nn-driver:2487 Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ibabb73c0ae0df2e530a68398f75c76e6b80c0701
-rw-r--r--include/armnn/ILayerSupport.hpp1
-rw-r--r--src/armnn/LayerSupport.cpp3
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp1
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp1
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp2
-rw-r--r--src/backends/cl/ClLayerSupport.cpp10
-rw-r--r--src/backends/cl/ClLayerSupport.hpp1
-rw-r--r--src/backends/cl/workloads/ClPermuteWorkload.cpp16
-rw-r--r--src/backends/cl/workloads/ClPermuteWorkload.hpp4
-rw-r--r--src/backends/cl/workloads/ClReshapeWorkload.cpp9
-rw-r--r--src/backends/cl/workloads/ClReshapeWorkload.hpp3
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp10
-rw-r--r--src/backends/neon/NeonLayerSupport.hpp1
-rw-r--r--src/backends/neon/workloads/NeonReshapeWorkload.cpp9
-rw-r--r--src/backends/neon/workloads/NeonReshapeWorkload.hpp5
-rw-r--r--src/backends/reference/RefLayerSupport.cpp2
-rw-r--r--src/backends/reference/RefLayerSupport.hpp1
17 files changed, 59 insertions, 20 deletions
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 54f4a2883b..fdd7fbb55c 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -276,6 +276,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
virtual bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index 7b9ada9150..79465cc151 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -531,11 +531,12 @@ bool IsPreluSupported(const BackendId& backend,
bool IsReshapeSupported(const BackendId& backend,
const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
char* reasonIfUnsupported,
size_t reasonIfUnsupportedMaxLength)
{
- FORWARD_LAYER_SUPPORT_FUNC(backend, IsReshapeSupported, input, descriptor);
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsReshapeSupported, input, output, descriptor);
}
bool IsResizeSupported(const BackendId& backend,
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 9ffad7b8e2..604a388a1f 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -419,6 +419,7 @@ bool LayerSupportBase::IsQuantizedLstmSupported(const TensorInfo& input,
}
bool LayerSupportBase::IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index e99cb67614..5568fca018 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -260,6 +260,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 9901dcb7c1..16185b6bc7 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -779,7 +779,9 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
{
auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
cLayer->GetParameters(),
reason);
break;
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index 7a1c573c0f..2998d293f9 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -45,6 +45,7 @@
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClPreluWorkload.hpp"
+#include "workloads/ClReshapeWorkload.hpp"
#include "workloads/ClResizeWorkload.hpp"
#include "workloads/ClRsqrtWorkload.hpp"
#include "workloads/ClQuantizedLstmWorkload.hpp"
@@ -603,9 +604,7 @@ bool ClLayerSupport::IsPermuteSupported(const TensorInfo& input,
const PermuteDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input);
- ignore_unused(output);
- FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
@@ -653,13 +652,12 @@ bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input,
}
bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input);
ignore_unused(descriptor);
- ignore_unused(reasonIfUnsupported);
- return true;
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClReshapeWorkloadValidate, reasonIfUnsupported, input, output);
}
bool ClLayerSupport::IsResizeSupported(const TensorInfo& input,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 219ce3b49e..b72f298928 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -210,6 +210,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/cl/workloads/ClPermuteWorkload.cpp b/src/backends/cl/workloads/ClPermuteWorkload.cpp
index bec80e55f8..41bce1d4fa 100644
--- a/src/backends/cl/workloads/ClPermuteWorkload.cpp
+++ b/src/backends/cl/workloads/ClPermuteWorkload.cpp
@@ -14,16 +14,16 @@
namespace armnn
{
-arm_compute::Status ClPermuteWorkloadValidate(const PermuteDescriptor& descriptor)
+arm_compute::Status ClPermuteWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const PermuteDescriptor& descriptor)
{
- const armnn::PermutationVector& perm = descriptor.m_DimMappings;
+ const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+ const armnn::PermutationVector& mappings = descriptor.m_DimMappings;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!perm.IsEqual({ 0U, 3U, 1U, 2U })
- && !perm.IsEqual({ 0U, 2U, 3U, 1U })
- && !perm.IsEqual({ 3U, 2U, 0U, 1U }),
- "Only [0, 3, 1, 2], [0, 2, 3, 1] and [3, 2, 0, 1] permutations are supported");
-
- return arm_compute::Status{};
+ return arm_compute::CLPermute::validate(&aclInputInfo, &aclOutputInfo,
+ armcomputetensorutils::BuildArmComputePermutationVector(mappings));
}
ClPermuteWorkload::ClPermuteWorkload(const PermuteQueueDescriptor& descriptor,
diff --git a/src/backends/cl/workloads/ClPermuteWorkload.hpp b/src/backends/cl/workloads/ClPermuteWorkload.hpp
index 58aa7ea0fa..8b5f4c6147 100644
--- a/src/backends/cl/workloads/ClPermuteWorkload.hpp
+++ b/src/backends/cl/workloads/ClPermuteWorkload.hpp
@@ -16,7 +16,9 @@
namespace armnn
{
-arm_compute::Status ClPermuteWorkloadValidate(const PermuteDescriptor& descriptor);
+arm_compute::Status ClPermuteWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output,
+ const PermuteDescriptor& descriptor);
class ClPermuteWorkload : public BaseWorkload<PermuteQueueDescriptor>
{
diff --git a/src/backends/cl/workloads/ClReshapeWorkload.cpp b/src/backends/cl/workloads/ClReshapeWorkload.cpp
index db1702a74f..d752290444 100644
--- a/src/backends/cl/workloads/ClReshapeWorkload.cpp
+++ b/src/backends/cl/workloads/ClReshapeWorkload.cpp
@@ -12,6 +12,15 @@
namespace armnn
{
+arm_compute::Status ClReshapeWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output)
+{
+ const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ return arm_compute::CLReshapeLayer::validate(&aclInputInfo, &aclOutputInfo);
+}
+
ClReshapeWorkload::ClReshapeWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info)
: BaseWorkload<ReshapeQueueDescriptor>(descriptor, info)
{
diff --git a/src/backends/cl/workloads/ClReshapeWorkload.hpp b/src/backends/cl/workloads/ClReshapeWorkload.hpp
index a7b464e719..62f5fccec8 100644
--- a/src/backends/cl/workloads/ClReshapeWorkload.hpp
+++ b/src/backends/cl/workloads/ClReshapeWorkload.hpp
@@ -12,6 +12,9 @@
namespace armnn
{
+arm_compute::Status ClReshapeWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output);
+
class ClReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
{
public:
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index ed0f41a888..151487d8c7 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -44,6 +44,7 @@
#include "workloads/NeonPreluWorkload.hpp"
#include "workloads/NeonQuantizeWorkload.hpp"
#include "workloads/NeonQuantizedLstmWorkload.hpp"
+#include "workloads/NeonReshapeWorkload.hpp"
#include "workloads/NeonResizeWorkload.hpp"
#include "workloads/NeonRsqrtWorkload.hpp"
#include "workloads/NeonSliceWorkload.hpp"
@@ -594,14 +595,15 @@ bool NeonLayerSupport::IsQuantizedLstmSupported(const TensorInfo& input,
}
bool NeonLayerSupport::IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- return IsSupportedForDataTypeNeon(reasonIfUnsupported,
- input.GetDataType(),
- &TrueFunc<>,
- &TrueFunc<>);
+ FORWARD_WORKLOAD_VALIDATE_FUNC(NeonReshapeWorkloadValidate,
+ reasonIfUnsupported,
+ input,
+ output);
}
bool NeonLayerSupport::IsResizeSupported(const TensorInfo& input,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index 5d4fbad97f..d9bb39950f 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -201,6 +201,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/neon/workloads/NeonReshapeWorkload.cpp b/src/backends/neon/workloads/NeonReshapeWorkload.cpp
index 7f2056c8e2..659bb94723 100644
--- a/src/backends/neon/workloads/NeonReshapeWorkload.cpp
+++ b/src/backends/neon/workloads/NeonReshapeWorkload.cpp
@@ -14,6 +14,15 @@
namespace armnn
{
+arm_compute::Status NeonReshapeWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output)
+{
+ const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ return arm_compute::NEReshapeLayer::validate(&aclInputInfo, &aclOutputInfo);
+}
+
NeonReshapeWorkload::NeonReshapeWorkload(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info)
: BaseWorkload<ReshapeQueueDescriptor>(descriptor, info)
diff --git a/src/backends/neon/workloads/NeonReshapeWorkload.hpp b/src/backends/neon/workloads/NeonReshapeWorkload.hpp
index 2202463928..186a02ba26 100644
--- a/src/backends/neon/workloads/NeonReshapeWorkload.hpp
+++ b/src/backends/neon/workloads/NeonReshapeWorkload.hpp
@@ -6,7 +6,10 @@
#pragma once
#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include <neon/workloads/NeonWorkloadUtils.hpp>
+#include <armnn/TypesUtils.hpp>
#include <arm_compute/runtime/IFunction.h>
#include <memory>
@@ -14,6 +17,8 @@
namespace armnn
{
+arm_compute::Status NeonReshapeWorkloadValidate(const TensorInfo& input, const TensorInfo& output);
+
class NeonReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
{
public:
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 5a84d8ac78..9f3835658c 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -1415,9 +1415,11 @@ bool RefLayerSupport::IsQuantizeSupported(const TensorInfo& input,
}
bool RefLayerSupport::IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
+ ignore_unused(output);
ignore_unused(descriptor);
// Define supported output types.
std::array<DataType,5> supportedOutputTypes =
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 04b355ee0a..d44bb8ff6d 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -231,6 +231,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;