aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl')
-rw-r--r--src/backends/cl/ClLayerSupport.cpp6
-rw-r--r--src/backends/cl/ClLayerSupport.hpp1
-rw-r--r--src/backends/cl/workloads/ClPermuteWorkload.cpp2
-rw-r--r--src/backends/cl/workloads/ClReshapeWorkload.cpp9
-rw-r--r--src/backends/cl/workloads/ClReshapeWorkload.hpp3
5 files changed, 17 insertions, 4 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index f8cc5074b3..ffe68a33d0 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -46,6 +46,7 @@
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dWorkload.hpp"
#include "workloads/ClPreluWorkload.hpp"
+#include "workloads/ClReshapeWorkload.hpp"
#include "workloads/ClResizeWorkload.hpp"
#include "workloads/ClRsqrtWorkload.hpp"
#include "workloads/ClQuantizedLstmWorkload.hpp"
@@ -670,13 +671,12 @@ bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input,
}
bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
{
- ignore_unused(input);
ignore_unused(descriptor);
- ignore_unused(reasonIfUnsupported);
- return true;
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClReshapeWorkloadValidate, reasonIfUnsupported, input, output);
}
bool ClLayerSupport::IsResizeSupported(const TensorInfo& input,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 9371717013..819d086cb4 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -216,6 +216,7 @@ public:
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
bool IsReshapeSupported(const TensorInfo& input,
+ const TensorInfo& output,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/cl/workloads/ClPermuteWorkload.cpp b/src/backends/cl/workloads/ClPermuteWorkload.cpp
index dd495c8288..41bce1d4fa 100644
--- a/src/backends/cl/workloads/ClPermuteWorkload.cpp
+++ b/src/backends/cl/workloads/ClPermuteWorkload.cpp
@@ -23,7 +23,7 @@ arm_compute::Status ClPermuteWorkloadValidate(const TensorInfo& input,
const armnn::PermutationVector& mappings = descriptor.m_DimMappings;
return arm_compute::CLPermute::validate(&aclInputInfo, &aclOutputInfo,
- armcomputetensorutils::BuildArmComputePermutationVector(mappings));
+ armcomputetensorutils::BuildArmComputePermutationVector(mappings));
}
ClPermuteWorkload::ClPermuteWorkload(const PermuteQueueDescriptor& descriptor,
diff --git a/src/backends/cl/workloads/ClReshapeWorkload.cpp b/src/backends/cl/workloads/ClReshapeWorkload.cpp
index db1702a74f..d752290444 100644
--- a/src/backends/cl/workloads/ClReshapeWorkload.cpp
+++ b/src/backends/cl/workloads/ClReshapeWorkload.cpp
@@ -12,6 +12,15 @@
namespace armnn
{
+arm_compute::Status ClReshapeWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output)
+{
+ const arm_compute::TensorInfo aclInputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
+
+ return arm_compute::CLReshapeLayer::validate(&aclInputInfo, &aclOutputInfo);
+}
+
ClReshapeWorkload::ClReshapeWorkload(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info)
: BaseWorkload<ReshapeQueueDescriptor>(descriptor, info)
{
diff --git a/src/backends/cl/workloads/ClReshapeWorkload.hpp b/src/backends/cl/workloads/ClReshapeWorkload.hpp
index a7b464e719..62f5fccec8 100644
--- a/src/backends/cl/workloads/ClReshapeWorkload.hpp
+++ b/src/backends/cl/workloads/ClReshapeWorkload.hpp
@@ -12,6 +12,9 @@
namespace armnn
{
+arm_compute::Status ClReshapeWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& output);
+
class ClReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
{
public: