aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/NeonTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r--src/backends/neon/NeonTensorHandle.hpp168
1 files changed, 165 insertions, 3 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp
index fcae77cdaa..e5f210773d 100644
--- a/src/backends/neon/NeonTensorHandle.hpp
+++ b/src/backends/neon/NeonTensorHandle.hpp
@@ -1,7 +1,8 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#pragma once
#include <BFloat16.hpp>
@@ -19,9 +20,11 @@
#include <arm_compute/runtime/SubTensor.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>
+#include "armnn/TypesUtils.hpp"
namespace armnn
{
+class NeonTensorHandleDecorator;
class NeonTensorHandle : public IAclTensorHandle
{
@@ -125,7 +128,7 @@ public:
virtual bool Import(void* memory, MemorySource source) override
{
- if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
+ if (m_ImportFlags& static_cast<MemorySourceFlags>(source))
{
if (source == MemorySource::Malloc && m_IsImportEnabled)
{
@@ -181,6 +184,8 @@ public:
return false;
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo& tensorInfo) override;
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -275,6 +280,7 @@ private:
bool m_Imported;
bool m_IsImportEnabled;
const uintptr_t m_TypeAlignment;
+ std::vector<std::shared_ptr<NeonTensorHandleDecorator>> m_Decorated;
};
class NeonSubTensorHandle : public IAclTensorHandle
@@ -283,7 +289,7 @@ public:
NeonSubTensorHandle(IAclTensorHandle* parent,
const arm_compute::TensorShape& shape,
const arm_compute::Coordinates& coords)
- : m_Tensor(&parent->GetTensor(), shape, coords)
+ : m_Tensor(&parent->GetTensor(), shape, coords, true)
{
parentHandle = parent;
}
@@ -319,6 +325,11 @@ public:
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
+ {
+ return nullptr;
+ };
+
private:
// Only used for testing
void CopyOutTo(void* memory) const override
@@ -394,4 +405,155 @@ private:
ITensorHandle* parentHandle = nullptr;
};
+/// NeonTensorDecorator wraps an existing Neon tensor allowing us to override the TensorInfo for it
+class NeonTensorDecorator : public arm_compute::ITensor
+{
+public:
+ NeonTensorDecorator();
+
+ NeonTensorDecorator(arm_compute::ITensor* original, const TensorInfo& info);
+
+ ~NeonTensorDecorator() = default;
+
+ NeonTensorDecorator(const NeonTensorDecorator&) = delete;
+
+ NeonTensorDecorator& operator=(const NeonTensorDecorator&) = delete;
+
+ NeonTensorDecorator(NeonTensorDecorator&&) = default;
+
+ NeonTensorDecorator& operator=(NeonTensorDecorator&&) = default;
+
+ // Inherited methods overridden:
+ arm_compute::ITensorInfo* info() const override;
+
+ arm_compute::ITensorInfo* info() override;
+
+ uint8_t* buffer() const override;
+
+private:
+ arm_compute::ITensor* m_Original;
+ mutable arm_compute::TensorInfo m_TensorInfo;
+};
+
+class NeonTensorHandleDecorator : public IAclTensorHandle
+{
+public:
+ NeonTensorHandleDecorator(IAclTensorHandle* parent, const TensorInfo& info)
+ : m_Tensor(&parent->GetTensor(), info)
+ {
+ parentHandle = parent;
+ }
+
+ arm_compute::ITensor& GetTensor() override { return m_Tensor; }
+ arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
+
+ virtual void Allocate() override {}
+ virtual void Manage() override {}
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual arm_compute::DataType GetDataType() const override
+ {
+ return m_Tensor.info()->data_type();
+ }
+
+ virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
+
+ virtual const void* Map(bool /* blocking = true */) const override
+ {
+ return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
+ }
+ virtual void Unmap() const override {}
+
+ TensorShape GetStrides() const override
+ {
+ return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
+ }
+
+ TensorShape GetShape() const override
+ {
+ return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
+ }
+
+ virtual std::shared_ptr<ITensorHandle> DecorateTensorHandle(const TensorInfo&) override
+ {
+ return nullptr;
+ };
+
+private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int8_t*>(memory));
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int16_t*>(memory));
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<int32_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::U8:
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QSYMM8:
+ case arm_compute::DataType::QASYMM8_SIGNED:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S16:
+ case arm_compute::DataType::QSYMM16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::S32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ NeonTensorDecorator m_Tensor;
+ ITensorHandle* parentHandle = nullptr;
+};
+
+
} // namespace armnn