diff options
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 87 |
1 files changed, 82 insertions, 5 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 63e2a781d6..7206b6fc5a 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -55,8 +55,6 @@ public: m_MemoryGroup->manage(&m_Tensor); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override @@ -87,6 +85,46 @@ public: } private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::Tensor m_Tensor; std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; }; @@ -108,8 +146,6 @@ public: virtual void Allocate() override {} virtual void Manage() override {} - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override @@ -134,9 +170,50 @@ public: { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::SubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; }; -} +} // namespace armnn
\ No newline at end of file |