// // Copyright © 2021 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include #include #include namespace armnn { constexpr const char* ClImportTensorHandleFactoryId() { return "Arm/Cl/ImportTensorHandleFactory"; } /** * This factory creates ClImportTensorHandles that refer to imported memory tensors. */ class ClImportTensorHandleFactory : public ITensorHandleFactory { public: static const FactoryId m_Id; /** * Create a tensor handle factory for tensors that will be imported or exported. * * @param importFlags * @param exportFlags */ ClImportTensorHandleFactory(MemorySourceFlags importFlags, MemorySourceFlags exportFlags) : m_ImportFlags(importFlags) , m_ExportFlags(exportFlags) {} std::unique_ptr CreateSubTensorHandle(ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, const bool IsMemoryManaged) const override; static const FactoryId& GetIdStatic(); const FactoryId& GetId() const override; bool SupportsSubTensors() const override; bool SupportsMapUnmap() const override; MemorySourceFlags GetExportFlags() const override; MemorySourceFlags GetImportFlags() const override; std::vector GetCapabilities(const IConnectableLayer* layer, const IConnectableLayer* connectedLayer, CapabilityClass capabilityClass) override; private: MemorySourceFlags m_ImportFlags; MemorySourceFlags m_ExportFlags; }; } // namespace armnn