From 1f58f03d82c482626b1b4673b6c0e25da4338fb5 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Tue, 27 Apr 2021 17:13:27 +0100 Subject: IVGCVSW-5815 Generalise ConstCpuTensorHandle * Generalises ConstCpuTensorHandle and inherited classes by removing 'Cpu' from aliases. * New renamed classes: ConstTensorHandle, TensorHandle, ScopedTensorHandle, PassthroughTensorHandle, ConstPassthroughTensorHandle. Signed-off-by: James Conroy Change-Id: I1824e0e134202735fb77051f20a7252f161dfe16 --- src/backends/backendsCommon/TensorHandle.cpp | 141 +++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 src/backends/backendsCommon/TensorHandle.cpp (limited to 'src/backends/backendsCommon/TensorHandle.cpp') diff --git a/src/backends/backendsCommon/TensorHandle.cpp b/src/backends/backendsCommon/TensorHandle.cpp new file mode 100644 index 0000000000..d4660d6de3 --- /dev/null +++ b/src/backends/backendsCommon/TensorHandle.cpp @@ -0,0 +1,141 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#include +#include + +#include + +#include + +namespace armnn +{ + +TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo) +{ + TensorShape shape(tensorInfo.GetShape()); + auto size = GetDataTypeSize(tensorInfo.GetDataType()); + auto runningSize = size; + std::vector strides(shape.GetNumDimensions()); + auto lastIdx = shape.GetNumDimensions()-1; + for (unsigned int i=0; i < lastIdx ; i++) + { + strides[lastIdx-i] = runningSize; + runningSize *= shape[lastIdx-i]; + } + strides[0] = runningSize; + return TensorShape(shape.GetNumDimensions(), strides.data()); +} + +ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo) +: m_TensorInfo(tensorInfo) +, m_Memory(nullptr) +{ +} + +template <> +const void* ConstTensorHandle::GetConstTensor() const +{ + return m_Memory; +} + +TensorHandle::TensorHandle(const TensorInfo& tensorInfo) +: ConstTensorHandle(tensorInfo) +, m_MutableMemory(nullptr) +{ +} + +template <> +void* TensorHandle::GetTensor() const +{ + return m_MutableMemory; +} + +ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo) +: TensorHandle(tensorInfo) +{ +} + +ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor) +: ScopedTensorHandle(tensor.GetInfo()) +{ + CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes()); +} + +ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle) +: ScopedTensorHandle(tensorHandle.GetTensorInfo()) +{ + CopyFrom(tensorHandle.GetConstTensor(), tensorHandle.GetTensorInfo().GetNumBytes()); +} + +ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other) +: TensorHandle(other.GetTensorInfo()) +{ + CopyFrom(other); +} + +ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other) +{ + ::operator delete(GetTensor()); + SetMemory(nullptr); + CopyFrom(other); + return *this; +} + +ScopedTensorHandle::~ScopedTensorHandle() +{ + ::operator delete(GetTensor()); +} + +void ScopedTensorHandle::Allocate() +{ + if (GetTensor() == nullptr) + { + SetMemory(::operator new(GetTensorInfo().GetNumBytes())); + } + else + { + throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle" + "that already has allocated memory."); + } +} + +void ScopedTensorHandle::CopyOutTo(void* memory) const +{ + memcpy(memory, GetTensor(), GetTensorInfo().GetNumBytes()); +} + +void ScopedTensorHandle::CopyInFrom(const void* memory) +{ + memcpy(GetTensor(), memory, GetTensorInfo().GetNumBytes()); +} + +void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other) +{ + CopyFrom(other.GetTensor(), other.GetTensorInfo().GetNumBytes()); +} + +void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes) +{ + ARMNN_ASSERT(GetTensor() == nullptr); + ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes); + + if (srcMemory) + { + Allocate(); + memcpy(GetTensor(), srcMemory, numBytes); + } +} + +void PassthroughTensorHandle::Allocate() +{ + throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called"); +} + +void ConstPassthroughTensorHandle::Allocate() +{ + throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called"); +} + +} // namespace armnn -- cgit v1.2.1