diff options
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp')
-rw-r--r-- | src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp | 58 |
1 files changed, 53 insertions, 5 deletions
diff --git a/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp b/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp index 6d13879f51..faa0d38386 100644 --- a/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp +++ b/src/backends/gpuFsa/GpuFsaWorkloadFactory.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -9,6 +9,11 @@ #include "GpuFsaBackendId.hpp" #include "GpuFsaTensorHandle.hpp" +#include "workloads/GpuFsaConstantWorkload.hpp" +#include "workloads/GpuFsaPreCompiledWorkload.hpp" + +#include <armnn/backends/MemCopyWorkload.hpp> + namespace armnn { @@ -43,11 +48,13 @@ bool IsDataType(const WorkloadInfo& info) GpuFsaWorkloadFactory::GpuFsaWorkloadFactory(const std::shared_ptr<GpuFsaMemoryManager>& memoryManager) : m_MemoryManager(memoryManager) { + InitializeCLCompileContext(); } GpuFsaWorkloadFactory::GpuFsaWorkloadFactory() : m_MemoryManager(new GpuFsaMemoryManager()) { + InitializeCLCompileContext(); } const BackendId& GpuFsaWorkloadFactory::GetBackendId() const @@ -81,11 +88,52 @@ std::unique_ptr<ITensorHandle> GpuFsaWorkloadFactory::CreateTensorHandle(const T return tensorHandle; } -std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType /*type*/, - const QueueDescriptor& /*descriptor*/, - const WorkloadInfo& /*info*/) const + +void GpuFsaWorkloadFactory::InitializeCLCompileContext() { + // Initialize our m_CLCompileContext using default device and context + auto context = arm_compute::CLKernelLibrary::get().context(); + auto device = arm_compute::CLKernelLibrary::get().get_device(); + m_CLCompileContext = arm_compute::CLCompileContext(context, device); +} + +std::unique_ptr<IWorkload> GpuFsaWorkloadFactory::CreateWorkload(LayerType type, + const QueueDescriptor& descriptor, + const WorkloadInfo& info) const { - return nullptr; + switch(type) + { + case LayerType::Constant : + { + auto constQueueDescriptor = PolymorphicDowncast<const ConstantQueueDescriptor*>(&descriptor); + return std::make_unique<GpuFsaConstantWorkload>(*constQueueDescriptor, info, m_CLCompileContext); + } + case LayerType::Input : + { + auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor); + return std::make_unique<CopyMemGenericWorkload>(*inputQueueDescriptor, info); + } + case LayerType::Output : + { + auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor); + return std::make_unique<CopyMemGenericWorkload>(*outputQueueDescriptor, info); + } + case LayerType::MemCopy : + { + auto memCopyQueueDescriptor = PolymorphicDowncast<const MemCopyQueueDescriptor*>(&descriptor); + if (memCopyQueueDescriptor->m_Inputs.empty() || !memCopyQueueDescriptor->m_Inputs[0]) + { + throw InvalidArgumentException("GpuFsaWorkloadFactory: Invalid null input for MemCopy workload"); + } + return std::make_unique<CopyMemGenericWorkload>(*memCopyQueueDescriptor, info); + } + case LayerType::PreCompiled : + { + auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor); + return std::make_unique<GpuFsaPreCompiledWorkload>(*precompiledQueueDescriptor, info); + } + default : + return nullptr; + } } } // namespace armnn
\ No newline at end of file |