// // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "TosaRefBackend.hpp" #include "TosaRefBackendId.hpp" #include "TosaRefWorkloadFactory.hpp" #include "TosaRefLayerSupport.hpp" #include "TosaRefTensorHandleFactory.hpp" #include #include #include #include #include #include #include #include namespace armnn { // Utility function to construct a valid Deleter for TosaSerializationHandler ptrs passed back to ArmNN template void DeleteAsType(const void* const blob) { delete static_cast(blob); } const BackendId& TosaRefBackend::GetIdStatic() { static const BackendId s_Id{TosaRefBackendId()}; return s_Id; } IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory( const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const { return std::make_unique(PolymorphicPointerDowncast(memoryManager)); } IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory( class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const { auto memoryManager = std::make_shared(); tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager); auto factory = std::make_unique(memoryManager); // Register copy and import factory pair tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); // Register the factory tensorHandleFactoryRegistry.RegisterFactory(std::move(factory)); return std::make_unique(PolymorphicPointerDowncast(memoryManager)); } IBackendInternal::IBackendContextPtr TosaRefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const { return IBackendContextPtr{}; } IBackendInternal::IBackendProfilingContextPtr TosaRefBackend::CreateBackendProfilingContext( const IRuntime::CreationOptions&, IBackendProfilingPtr&) { return IBackendProfilingContextPtr{}; } IBackendInternal::IMemoryManagerUniquePtr TosaRefBackend::CreateMemoryManager() const { return std::make_unique(); } IBackendInternal::ILayerSupportSharedPtr TosaRefBackend::GetLayerSupport() const { static ILayerSupportSharedPtr layerSupport{new TosaRefLayerSupport}; return layerSupport; } OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgraph, const ModelOptions& modelOptions) const { OptimizationViews optimizationViews(modelOptions); auto handler = std::make_unique(); // A main block should only be added once. bool isMain = true; auto it = subgraph.endIConnectable(); while (it != subgraph.beginIConnectable()) { --it; Layer &base = *(PolymorphicDowncast(*it)); if(base.GetType() == armnn::LayerType::Input || base.GetType() == armnn::LayerType::Output) { continue; } tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base, isMain); handler.get()->GetBlocks().push_back(mappings); if(isMain) { isMain = false; } } auto compiledBlob = std::make_unique(handler.release(), DeleteAsType); IConnectableLayer* preCompiledLayer = optimizationViews.GetINetwork()->AddPrecompiledLayer( PreCompiledDescriptor(subgraph.GetNumInputSlots(), subgraph.GetNumOutputSlots()), std::move(*compiledBlob), armnn::Optional(GetId()), "TOSA_Pre_Compiled_Layer"); // Copy the output tensor infos from sub-graph for (unsigned int i = 0; i < subgraph.GetNumOutputSlots(); i++) { preCompiledLayer->GetOutputSlot(i).SetTensorInfo(subgraph.GetIOutputSlot(i)->GetTensorInfo()); } optimizationViews.AddSubstitution({ std::move(subgraph), SubgraphView(preCompiledLayer) }); return optimizationViews; } std::vector TosaRefBackend::GetHandleFactoryPreferences() const { return std::vector { TosaRefTensorHandleFactory::GetIdStatic() }; } void TosaRefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry) { auto memoryManager = std::make_shared(); registry.RegisterMemoryManager(memoryManager); auto factory = std::make_unique(memoryManager); // Register copy and import factory pair registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId()); // Register the factory registry.RegisterFactory(std::move(factory)); } std::unique_ptr TosaRefBackend::GetDefaultAllocator() const { return std::make_unique(); } } // namespace armnn