From 09e2f27a9da7a65eb409f3dbdfc029eb3afbb003 Mon Sep 17 00:00:00 2001 From: David Beck Date: Tue, 30 Oct 2018 11:38:41 +0000 Subject: IVGCVSW-1949 : Refactor ITensorHandle and move backend specifics to their place Change-Id: I48242425c6a6856e13ebcee1b140cbd2af94a3aa --- src/backends/neon/NeonTensorHandle.hpp | 87 +++++++++++++++++++++++++++++-- src/backends/neon/NeonWorkloadFactory.cpp | 2 - 2 files changed, 82 insertions(+), 7 deletions(-) (limited to 'src/backends/neon') diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 63e2a781d6..7206b6fc5a 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -55,8 +55,6 @@ public: m_MemoryGroup->manage(&m_Tensor); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override @@ -87,6 +85,46 @@ public: } private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::Tensor m_Tensor; std::shared_ptr m_MemoryGroup; }; @@ -108,8 +146,6 @@ public: virtual void Allocate() override {} virtual void Manage() override {} - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override @@ -134,9 +170,50 @@ public: { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::SubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; }; -} +} // namespace armnn \ No newline at end of file diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index fc3890684d..604686771e 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -54,8 +54,6 @@ std::unique_ptr NeonWorkloadFactory::CreateSubTensorHandle(ITenso TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const { - BOOST_ASSERT(parent.GetType() == ITensorHandle::Neon); - const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape); arm_compute::Coordinates coords; -- cgit v1.2.1