aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaReference/TosaRefWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaReference/TosaRefWorkloadFactory.cpp')
-rw-r--r--src/backends/tosaReference/TosaRefWorkloadFactory.cpp126
1 files changed, 126 insertions, 0 deletions
diff --git a/src/backends/tosaReference/TosaRefWorkloadFactory.cpp b/src/backends/tosaReference/TosaRefWorkloadFactory.cpp
new file mode 100644
index 0000000000..e6b5f9e934
--- /dev/null
+++ b/src/backends/tosaReference/TosaRefWorkloadFactory.cpp
@@ -0,0 +1,126 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include <Layer.hpp>
+#include <armnn/backends/MemCopyWorkload.hpp>
+#include <backendsCommon/MemImportWorkload.hpp>
+#include <backendsCommon/MakeWorkloadHelper.hpp>
+#include <armnn/backends/TensorHandle.hpp>
+#include "TosaRefWorkloadFactory.hpp"
+#include "TosaRefBackendId.hpp"
+#include "workloads/TosaRefWorkloads.hpp"
+#include "TosaRefTensorHandle.hpp"
+#include "TosaRefWorkloadFactory.hpp"
+
+
+namespace armnn
+{
+
+namespace
+{
+static const BackendId s_Id{TosaRefBackendId()};
+}
+template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
+std::unique_ptr<IWorkload> TosaRefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
+ (descriptor, info);
+}
+
+template <DataType ArmnnType>
+bool IsDataType(const WorkloadInfo& info)
+{
+ auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
+ auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
+ if (it != std::end(info.m_InputTensorInfos))
+ {
+ return true;
+ }
+ it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
+ if (it != std::end(info.m_OutputTensorInfos))
+ {
+ return true;
+ }
+ return false;
+}
+
+TosaRefWorkloadFactory::TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager)
+ : m_MemoryManager(memoryManager)
+{
+}
+
+TosaRefWorkloadFactory::TosaRefWorkloadFactory()
+ : m_MemoryManager(new TosaRefMemoryManager())
+{
+}
+
+const BackendId& TosaRefWorkloadFactory::GetBackendId() const
+{
+ return s_Id;
+}
+
+bool TosaRefWorkloadFactory::IsLayerSupported(const Layer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported)
+{
+ return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
+}
+
+bool TosaRefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions)
+{
+ return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
+}
+
+std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool isMemoryManaged) const
+{
+ if (isMemoryManaged)
+ {
+ return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
+ }
+ else
+ {
+ return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+ }
+}
+
+std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout,
+ const bool isMemoryManaged) const
+{
+ // For TosaRef it is okay to make the TensorHandle memory managed as it can also store a pointer
+ // to unmanaged memory. This also ensures memory alignment.
+ IgnoreUnused(isMemoryManaged, dataLayout);
+
+ if (isMemoryManaged)
+ {
+ return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
+ }
+ else
+ {
+ return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
+ }
+}
+
+std::unique_ptr<IWorkload> TosaRefWorkloadFactory::CreateWorkload(LayerType type,
+ const QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ switch(type)
+ {
+ case LayerType::PreCompiled:
+ {
+ auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
+ return std::make_unique<TosaRefPreCompiledWorkload>(*precompiledQueueDescriptor, info);
+ }
+ default:
+ return nullptr;
+ }
+}
+
+} // namespace armnn