diff options
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp')
-rw-r--r-- | src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp | 67 |
1 files changed, 46 insertions, 21 deletions
diff --git a/src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp b/src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp index cd9d8cd64d..c1a34d24e5 100644 --- a/src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp +++ b/src/backends/gpuFsa/GpuFsaTensorHandleFactory.cpp @@ -1,32 +1,50 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "GpuFsaTensorHandle.hpp" #include "GpuFsaTensorHandleFactory.hpp" -#include "armnn/Logging.hpp" -#include <armnn/utility/IgnoreUnused.hpp> - namespace armnn { using FactoryId = ITensorHandleFactory::FactoryId; -const FactoryId& GpuFsaTensorHandleFactory::GetIdStatic() -{ - static const FactoryId s_Id(GpuFsaTensorHandleFactoryId()); - return s_Id; -} - std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent, - const TensorShape& subTensorShape, - const unsigned int* subTensorOrigin) - const + const TensorShape& subTensorShape, + const unsigned int* subTensorOrigin) const { - IgnoreUnused(parent, subTensorShape, subTensorOrigin); - return nullptr; + arm_compute::Coordinates coords; + arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape); + + coords.set_num_dimensions(subTensorShape.GetNumDimensions()); + for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i) + { + // Arm compute indexes tensor coords in reverse order. + unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1; + coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex])); + } + + const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape()); + + // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y + // must match the parent shapes + if (coords.x() != 0 || coords.y() != 0) + { + return nullptr; + } + if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y())) + { + return nullptr; + } + + if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape)) + { + return nullptr; + } + + return std::make_unique<GpuFsaSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords); } std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const @@ -43,25 +61,32 @@ std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(con std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, const bool IsMemoryManaged) const { - std::unique_ptr<GpuFsaTensorHandle> handle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, m_MemoryManager); + std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo); if (!IsMemoryManaged) { ARMNN_LOG(warning) << "GpuFsaTensorHandleFactory only has support for memory managed."; } - return handle; + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); + return tensorHandle; } std::unique_ptr<ITensorHandle> GpuFsaTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout, const bool IsMemoryManaged) const { - IgnoreUnused(dataLayout); - std::unique_ptr<GpuFsaTensorHandle> handle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, m_MemoryManager); + std::unique_ptr<GpuFsaTensorHandle> tensorHandle = std::make_unique<GpuFsaTensorHandle>(tensorInfo, dataLayout); if (!IsMemoryManaged) { ARMNN_LOG(warning) << "GpuFsaTensorHandleFactory only has support for memory managed."; } - return handle; + tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup()); + return tensorHandle; +} + +const FactoryId& GpuFsaTensorHandleFactory::GetIdStatic() +{ + static const FactoryId s_Id(GpuFsaTensorHandleFactoryId()); + return s_Id; } const FactoryId& GpuFsaTensorHandleFactory::GetId() const @@ -71,7 +96,7 @@ const FactoryId& GpuFsaTensorHandleFactory::GetId() const bool GpuFsaTensorHandleFactory::SupportsSubTensors() const { - return false; + return true; } MemorySourceFlags GpuFsaTensorHandleFactory::GetExportFlags() const |