diff options
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 87 | ||||
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp | 2 |
2 files changed, 82 insertions, 7 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index 63e2a781d6..7206b6fc5a 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -55,8 +55,6 @@ public: m_MemoryGroup->manage(&m_Tensor); } - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return nullptr; } virtual arm_compute::DataType GetDataType() const override @@ -87,6 +85,46 @@ public: } private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::Tensor m_Tensor; std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; }; @@ -108,8 +146,6 @@ public: virtual void Allocate() override {} virtual void Manage() override {} - virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; } - virtual ITensorHandle* GetParent() const override { return parentHandle; } virtual arm_compute::DataType GetDataType() const override @@ -134,9 +170,50 @@ public: { return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); } + private: + // Only used for testing + void CopyOutTo(void* memory) const override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<float*>(memory)); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), + static_cast<uint8_t*>(memory)); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + + // Only used for testing + void CopyInFrom(const void* memory) override + { + switch (this->GetDataType()) + { + case arm_compute::DataType::F32: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), + this->GetTensor()); + break; + case arm_compute::DataType::QASYMM8: + armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), + this->GetTensor()); + break; + default: + { + throw armnn::UnimplementedException(); + } + } + } + arm_compute::SubTensor m_Tensor; ITensorHandle* parentHandle = nullptr; }; -} +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index fc3890684d..604686771e 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -54,8 +54,6 @@ std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateSubTensorHandle(ITenso TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const { - BOOST_ASSERT(parent.GetType() == ITensorHandle::Neon); - const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape); arm_compute::Coordinates coords; |