aboutsummaryrefslogtreecommitdiff
path: root/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp')
-rw-r--r-- src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp 58
1 files changed, 53 insertions, 5 deletions
diff --git a/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp b/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp
index 6d13879f51..faa0d38386 100644
--- a/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp
+++ b/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -9,6 +9,11 @@
#include "GpuFsaBackendId.hpp"
#include "GpuFsaTensorHandle.hpp"
+#include "workloads/GpuFsaConstantWorkload.hpp"
+#include "workloads/GpuFsaPreCompiledWorkload.hpp"
+
+#include <armnn/backends/MemCopyWorkload.hpp>
+
namespace armnn
{
@@ -43,11 +48,13 @@ bool IsDataType(const WorkloadInfo& info)
GpuFsaWorkloadFactory::GpuFsaWorkloadFactory(const std::shared_ptr<GpuFsaMemoryManager>& memoryManager)
: m_MemoryManager(memoryManager)
{
+ InitializeCLCompileContext(); // Fill m_CLCompileContext at construction; CreateWorkload passes it to GpuFsaConstantWorkload.
}
GpuFsaWorkloadFactory::GpuFsaWorkloadFactory()
: m_MemoryManager(new GpuFsaMemoryManager())
{
+ InitializeCLCompileContext(); // Same setup as the memory-manager overload: m_CLCompileContext is filled at construction.
}
const BackendId& GpuFsaWorkloadFactory::GetBackendId() const
@@ -81,11 +88,52 @@ std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const T
return tensorHandle;
}
-std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType /*type*/,
- const QueueDescriptor& /*descriptor*/,
- const WorkloadInfo& /*info*/) const
+
+void GpuFsaWorkloadFactory::InitializeCLCompileContext() { // Called from both constructors; takes no arguments and returns nothing.
+ // Initialize our m_CLCompileContext using default device and context
+ auto context = arm_compute::CLKernelLibrary::get().context(); // Context as reported by CLKernelLibrary::get().
+ auto device = arm_compute::CLKernelLibrary::get().get_device(); // Device as reported by CLKernelLibrary::get().
+ m_CLCompileContext = arm_compute::CLCompileContext(context, device); // Overwrites the member with a context built from the pair above.
+}
+
+std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType type, // Selects which workload class is built.
+ const QueueDescriptor& descriptor, // Generic descriptor; downcast per case to the layer-specific type.
+ const WorkloadInfo& info) const // Forwarded unchanged to the workload constructor.
{
- return nullptr;
+ switch(type) // One case per layer type this factory can build; anything else yields nullptr.
+ {
+ case LayerType::Constant :
+ {
+ auto constQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor);
+ return std::make_unique<GpuFsaConstantWorkload>(*constQueueDescriptor, info, m_CLCompileContext); // Only case that receives m_CLCompileContext.
+ }
+ case LayerType::Input :
+ {
+ auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor);
+ return std::make_unique<CopyMemGenericWorkload>(*inputQueueDescriptor, info); // Input is served by CopyMemGenericWorkload.
+ }
+ case LayerType::Output :
+ {
+ auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor);
+ return std::make_unique<CopyMemGenericWorkload>(*outputQueueDescriptor, info); // Output uses the same workload class as Input.
+ }
+ case LayerType::MemCopy :
+ {
+ auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor);
+ if (memCopyQueueDescriptor->m_Inputs.empty() || !memCopyQueueDescriptor->m_Inputs[0]) // Guard: no inputs at all, or a null first input.
+ {
+ throw InvalidArgumentException("GpuFsaWorkloadFactory: Invalid null input for MemCopy workload");
+ }
+ return std::make_unique<CopyMemGenericWorkload>(*memCopyQueueDescriptor, info); // Reached only when the first input is present and non-null.
+ }
+ case LayerType::PreCompiled :
+ {
+ auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
+ return std::make_unique<GpuFsaPreCompiledWorkload>(*precompiledQueueDescriptor, info); // Built from descriptor and info only.
+ }
+ default : // Every layer type not listed above.
+ return nullptr; // Callers get a null workload rather than an exception for unhandled types.
+ }
}
} // namespace armnn
\ No newline at end of file