diff options
author | David Beck <david.beck@arm.com> | 2018-10-30 11:38:41 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2018-11-07 13:24:50 +0000 |
commit | 09e2f27a9da7a65eb409f3dbdfc029eb3afbb003 (patch) | |
tree | a2af70b701dca0f4688610dffbe68a74778289d3 /src/backends/neon | |
parent | 9efb57d62197aeb7d868c289bb34166c132f0287 (diff) | |
download | armnn-09e2f27a9da7a65eb409f3dbdfc029eb3afbb003.tar.gz |
IVGCVSW-1949 : Refactor ITensorHandle and move backend specifics to their place
Change-Id: I48242425c6a6856e13ebcee1b140cbd2af94a3aa
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 87 | ||||
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp | 2 |
2 files changed, 82 insertions, 7 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 63e2a781d6..7206b6fc5a 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -55,8 +55,6 @@ public: m_MemoryGroup->manage(&m_Tensor); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override @@ -87,6 +85,46 @@ public: } private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::Tensor m_Tensor; std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; }; @@ -108,8 +146,6 @@ public: virtual void Allocate() override {} virtual void Manage() override {} - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override @@ -134,9 +170,50 @@ public: { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::SubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; }; -} +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index fc3890684d..604686771e 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -54,8 +54,6 @@ std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateSubTensorHandle(ITenso TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const { - BOOST_ASSERT(parent.GetType() == ITensorHandle::Neon); - const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape); arm_compute::Coordinates coords; |