// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include namespace armnn { constexpr const char* NeonTensorHandleFactoryId() { return "Arm/Neon/TensorHandleFactory"; } const std::set paddingRequiredLayers { LayerType::ArgMinMax, LayerType::Concat, LayerType::Convolution2d, LayerType::DepthToSpace, LayerType::DepthwiseConvolution2d, LayerType::Dequantize, LayerType::FullyConnected, LayerType::Gather, LayerType::L2Normalization, LayerType::Lstm, LayerType::Mean, LayerType::Multiplication, LayerType::Normalization, LayerType::Permute, LayerType::Pooling2d, LayerType::Quantize, LayerType::QuantizedLstm, LayerType::Resize, LayerType::Stack, LayerType::Transpose, LayerType::TransposeConvolution2d }; class NeonTensorHandleFactory : public ITensorHandleFactory { public: NeonTensorHandleFactory(std::weak_ptr mgr) : m_MemoryManager(mgr), m_ImportFlags(static_cast(MemorySource::Malloc)), m_ExportFlags(static_cast(MemorySource::Malloc)) {} std::unique_ptr CreateSubTensorHandle(ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const override; std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, const bool IsMemoryManaged = true) const override; static const FactoryId& GetIdStatic(); const FactoryId& GetId() const override; bool SupportsSubTensors() const override; MemorySourceFlags GetExportFlags() const override; MemorySourceFlags GetImportFlags() const override; std::vector GetCapabilities(const IConnectableLayer* layer, const IConnectableLayer* connectedLayer, CapabilityClass capabilityClass) override; private: mutable std::shared_ptr m_MemoryManager; MemorySourceFlags m_ImportFlags; MemorySourceFlags m_ExportFlags; }; } // namespace armnn