aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp69
1 files changed, 50 insertions, 19 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 68d371388c..e1d8314d82 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -16,12 +16,13 @@
#include <arm_compute/runtime/CL/CLBufferAllocator.h>
#include <arm_compute/runtime/CL/CLScheduler.h>
-#include <backends/cl/workloads/ClWorkloads.hpp>
-
#include <backends/MemCopyWorkload.hpp>
-#include <backends/cl/ClTensorHandle.hpp>
#include <backends/aclCommon/memory/IPoolManager.hpp>
+
+#include <backends/cl/ClTensorHandle.hpp>
+#include <backends/cl/workloads/ClWorkloads.hpp>
+#include <backends/cl/workloads/ClWorkloadUtils.hpp>
#endif
#include <backends/MakeWorkloadHelper.hpp>
@@ -42,6 +43,36 @@ bool ClWorkloadFactory::IsLayerSupported(const Layer& layer,
#ifdef ARMCOMPUTECL_ENABLED
+template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
+std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
+ const WorkloadInfo& info,
+ Args&&... args)
+{
+ try
+ {
+ return MakeWorkloadHelper<FloatWorkload, Uint8Workload>(descriptor, info, std::forward<Args>(args)...);
+ }
+ catch (const cl::Error& clError)
+ {
+ throw WrapClError(clError, CHECK_LOCATION());
+ }
+}
+
+template <typename Workload, typename QueueDescriptorType, typename... Args>
+std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
+ const WorkloadInfo& info,
+ Args&&... args)
+{
+ try
+ {
+ return std::make_unique<Workload>(descriptor, info, std::forward<Args>(args)...);
+ }
+ catch (const cl::Error& clError)
+ {
+ throw WrapClError(clError, CHECK_LOCATION());
+ }
+}
+
ClWorkloadFactory::ClWorkloadFactory()
: m_MemoryManager(std::make_unique<arm_compute::CLBufferAllocator>())
{
@@ -100,26 +131,26 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateOutput(const OutputQueueDesc
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClActivationWorkload>(descriptor, info);
+ return MakeWorkload<ClActivationWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
return MakeWorkload<ClSoftmaxFloatWorkload, ClSoftmaxUint8Workload>(descriptor, info,
- m_MemoryManager.GetIntraLayerManager());
+ m_MemoryManager.GetIntraLayerManager());
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClSplitterWorkload>(descriptor, info);
+ return MakeWorkload<ClSplitterWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClMergerWorkload>(descriptor, info);
+ return MakeWorkload<ClMergerWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
@@ -132,25 +163,25 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClPermuteWorkload>(descriptor, info);
+ return MakeWorkload<ClPermuteWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClPooling2dWorkload>(descriptor, info);
+ return MakeWorkload<ClPooling2dWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager.GetIntraLayerManager());
+ return MakeWorkload<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager.GetIntraLayerManager());
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDepthwiseConvolution2d(
const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return std::make_unique<ClDepthwiseConvolutionWorkload>(descriptor, info);
+ return MakeWorkload<ClDepthwiseConvolutionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
@@ -162,13 +193,13 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateNormalization(const N
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClAdditionWorkload>(descriptor, info);
+ return MakeWorkload<ClAdditionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMultiplication(
const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const
{
- return std::make_unique<ClMultiplicationWorkload>(descriptor, info);
+ return MakeWorkload<ClMultiplicationWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
@@ -180,7 +211,7 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateDivision(
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClSubtractionWorkload>(descriptor, info);
+ return MakeWorkload<ClSubtractionWorkload>(descriptor, info);
}
std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateBatchNormalization(
@@ -223,13 +254,13 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateL2Normalization(const L2Norm
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClConstantWorkload>(descriptor, info);
+ return MakeWorkload<ClConstantWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClReshapeWorkload>(descriptor, info);
+ return MakeWorkload<ClReshapeWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
@@ -248,14 +279,14 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp16ToFp32(
const ConvertFp16ToFp32QueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClConvertFp16ToFp32Workload>(descriptor, info);
+ return MakeWorkload<ClConvertFp16ToFp32Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16(
const ConvertFp32ToFp16QueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClConvertFp32ToFp16Workload>(descriptor, info);
+ return MakeWorkload<ClConvertFp32ToFp16Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
@@ -267,7 +298,7 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescript
std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<ClPadWorkload>(descriptor, info);
+ return MakeWorkload<ClPadWorkload>(descriptor, info);
}
void ClWorkloadFactory::Finalize()