aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/NeonTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r--src/backends/neon/NeonTensorHandle.hpp87
1 files changed, 82 insertions, 5 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp
index 63e2a781d6..7206b6fc5a 100644
--- a/src/backends/neon/NeonTensorHandle.hpp
+++ b/src/backends/neon/NeonTensorHandle.hpp
@@ -55,8 +55,6 @@ public:
m_MemoryGroup->manage(&m_Tensor);
}
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
-
virtual ITensorHandle* GetParent() const override { return nullptr; }
virtual arm_compute::DataType GetDataType() const override
@@ -87,6 +85,46 @@ public:
}
private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
arm_compute::Tensor m_Tensor;
std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
};
@@ -108,8 +146,6 @@ public:
virtual void Allocate() override {}
virtual void Manage() override {}
- virtual ITensorHandle::Type GetType() const override { return ITensorHandle::Neon; }
-
virtual ITensorHandle* GetParent() const override { return parentHandle; }
virtual arm_compute::DataType GetDataType() const override
@@ -134,9 +170,50 @@ public:
{
return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
}
+
private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<float*>(memory));
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<uint8_t*>(memory));
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
+ // Only used for testing
+ void CopyInFrom(const void* memory) override
+ {
+ switch (this->GetDataType())
+ {
+ case arm_compute::DataType::F32:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
+ this->GetTensor());
+ break;
+ case arm_compute::DataType::QASYMM8:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
+ this->GetTensor());
+ break;
+ default:
+ {
+ throw armnn::UnimplementedException();
+ }
+ }
+ }
+
arm_compute::SubTensor m_Tensor;
ITensorHandle* parentHandle = nullptr;
};
-}
+} // namespace armnn \ No newline at end of file