// // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include #include "GpuFsaWorkloadFactory.hpp" #include "GpuFsaBackendId.hpp" #include "GpuFsaTensorHandle.hpp" namespace armnn { namespace { static const BackendId s_Id{GpuFsaBackendId()}; } template std::unique_ptr GpuFsaWorkloadFactory::MakeWorkload(const QueueDescriptorType& /*descriptor*/, const WorkloadInfo& /*info*/) const { return nullptr; } template bool IsDataType(const WorkloadInfo& info) { auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;}; auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType); if (it != std::end(info.m_InputTensorInfos)) { return true; } it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType); if (it != std::end(info.m_OutputTensorInfos)) { return true; } return false; } GpuFsaWorkloadFactory::GpuFsaWorkloadFactory(const std::shared_ptr& memoryManager) : m_MemoryManager(memoryManager) { } GpuFsaWorkloadFactory::GpuFsaWorkloadFactory() : m_MemoryManager(new GpuFsaMemoryManager()) { } const BackendId& GpuFsaWorkloadFactory::GetBackendId() const { return s_Id; } bool GpuFsaWorkloadFactory::IsLayerSupported(const Layer& layer, Optional dataType, std::string& outReasonIfUnsupported) { return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported); } std::unique_ptr GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool /*isMemoryManaged*/) const { std::unique_ptr tensorHandle = std::make_unique(tensorInfo); tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); return tensorHandle; } std::unique_ptr GpuFsaWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, const bool /*isMemoryManaged*/) const { std::unique_ptr tensorHandle = std::make_unique(tensorInfo, dataLayout); tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); return tensorHandle; } std::unique_ptr GpuFsaWorkloadFactory::CreateWorkload(LayerType /*type*/, const QueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const { return nullptr; } } // namespace armnn