diff options
Diffstat (limited to 'src/backends/tosaReference/TosaRefWorkloadFactory.cpp')
-rw-r--r-- | src/backends/tosaReference/TosaRefWorkloadFactory.cpp | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/src/backends/tosaReference/TosaRefWorkloadFactory.cpp b/src/backends/tosaReference/TosaRefWorkloadFactory.cpp new file mode 100644 index 0000000000..e6b5f9e934 --- /dev/null +++ b/src/backends/tosaReference/TosaRefWorkloadFactory.cpp @@ -0,0 +1,126 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include <Layer.hpp> +#include <armnn/backends/MemCopyWorkload.hpp> +#include <backendsCommon/MemImportWorkload.hpp> +#include <backendsCommon/MakeWorkloadHelper.hpp> +#include <armnn/backends/TensorHandle.hpp> +#include "TosaRefWorkloadFactory.hpp" +#include "TosaRefBackendId.hpp" +#include "workloads/TosaRefWorkloads.hpp" +#include "TosaRefTensorHandle.hpp" +#include "TosaRefWorkloadFactory.hpp" + + +namespace armnn +{ + +namespace +{ +static const BackendId s_Id{TosaRefBackendId()}; +} +template <typename F32Workload, typename U8Workload, typename QueueDescriptorType> +std::unique_ptr<IWorkload> TosaRefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, + const WorkloadInfo& info) const +{ + return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload> + (descriptor, info); +} + +template <DataType ArmnnType> +bool IsDataType(const WorkloadInfo& info) +{ + auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;}; + auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType); + if (it != std::end(info.m_InputTensorInfos)) + { + return true; + } + it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType); + if (it != std::end(info.m_OutputTensorInfos)) + { + return true; + } + return false; +} + +TosaRefWorkloadFactory::TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager) + : m_MemoryManager(memoryManager) +{ +} + +TosaRefWorkloadFactory::TosaRefWorkloadFactory() + : m_MemoryManager(new TosaRefMemoryManager()) +{ +} + +const BackendId& TosaRefWorkloadFactory::GetBackendId() const +{ + return s_Id; +} + +bool TosaRefWorkloadFactory::IsLayerSupported(const Layer& layer, + Optional<DataType> dataType, + std::string& outReasonIfUnsupported) +{ + return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported); +} + +bool TosaRefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer, + Optional<DataType> dataType, + std::string& outReasonIfUnsupported, + const ModelOptions& modelOptions) +{ + return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions); +} + +std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + const bool isMemoryManaged) const +{ + if (isMemoryManaged) + { + return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager); + } + else + { + return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc)); + } +} + +std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout, + const bool isMemoryManaged) const +{ + // For TosaRef it is okay to make the TensorHandle memory managed as it can also store a pointer + // to unmanaged memory. This also ensures memory alignment. + IgnoreUnused(isMemoryManaged, dataLayout); + + if (isMemoryManaged) + { + return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager); + } + else + { + return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc)); + } +} + +std::unique_ptr<IWorkload> TosaRefWorkloadFactory::CreateWorkload(LayerType type, + const QueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + switch(type) + { + case LayerType::PreCompiled: + { + auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor); + return std::make_unique<TosaRefPreCompiledWorkload>(*precompiledQueueDescriptor, info); + } + default: + return nullptr; + } +} + +} // namespace armnn |