diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-08-16 10:28:37 +0100 |
---|---|---|
committer | David Monahan <david.monahan@arm.com> | 2019-08-20 10:41:18 +0100 |
commit | e9f0f0fdda42ba17085ba4d86e0c84ab68ed2f5a (patch) | |
tree | 5b3f1a2471d9ba0bd9d3698a8135649414ca328e /src/backends/cl/ClTensorHandleFactory.cpp | |
parent | 895339092fa9edc0aa59de0309f79bebacc3fa63 (diff) | |
download | armnn-e9f0f0fdda42ba17085ba4d86e0c84ab68ed2f5a.tar.gz |
IVGCVSW-3617 Add CL TensorHandleFactory
* Adds ClTensorHandleFactory
* Includes some refactoring of NeonTensorHandleFactory
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I9f0874b1945050267c5ece970e6d9b200ed8a865
Diffstat (limited to 'src/backends/cl/ClTensorHandleFactory.cpp')
-rw-r--r-- | src/backends/cl/ClTensorHandleFactory.cpp | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/src/backends/cl/ClTensorHandleFactory.cpp b/src/backends/cl/ClTensorHandleFactory.cpp new file mode 100644 index 0000000000..47e36b3c76 --- /dev/null +++ b/src/backends/cl/ClTensorHandleFactory.cpp @@ -0,0 +1,91 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + + +#include "ClTensorHandleFactory.hpp" +#include "ClTensorHandle.hpp" + +#include <arm_compute/runtime/CL/CLTensor.h> +#include <arm_compute/core/Coordinates.h> +#include <arm_compute/runtime/CL/CLSubTensor.h> + +#include <boost/polymorphic_cast.hpp> + + +namespace armnn +{ + +using FactoryId = ITensorHandleFactory::FactoryId; + +std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent, + const TensorShape& subTensorShape, + const unsigned int* subTensorOrigin) const +{ + arm_compute::Coordinates coords; + arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape); + + coords.set_num_dimensions(subTensorShape.GetNumDimensions()); + for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i) + { + // Arm compute indexes tensor coords in reverse order. + unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1; + coords.set(i, boost::numeric_cast<int>(subTensorOrigin[revertedIndex])); + } + + const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape( + parent.GetShape()); + if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape)) + { + return nullptr; + } + + return std::make_unique<ClSubTensorHandle>( + boost::polymorphic_downcast<IClTensorHandle *>(&parent), shape, coords); +} + +std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const +{ + std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo); + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); + + return tensorHandle; +} + +std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout) const +{ + std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout); + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); + + return tensorHandle; +} + +const FactoryId& ClTensorHandleFactory::GetIdStatic() +{ + static const FactoryId s_Id(ClTensorHandleFactoryId()); + return s_Id; +} + +const FactoryId ClTensorHandleFactory::GetId() const +{ + return GetIdStatic(); +} + +bool ClTensorHandleFactory::SupportsSubTensors() const +{ + return true; +} + +MemorySourceFlags ClTensorHandleFactory::GetExportFlags() const +{ + return m_ExportFlags; +} + +MemorySourceFlags ClTensorHandleFactory::GetImportFlags() const +{ + return m_ImportFlags; +} + +} // namespace armnn
\ No newline at end of file |