aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/NeonTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r--src/backends/neon/NeonTensorHandle.hpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp
index 2e9be11be1..11d20878d7 100644
--- a/src/backends/neon/NeonTensorHandle.hpp
+++ b/src/backends/neon/NeonTensorHandle.hpp
@@ -4,6 +4,7 @@
//
#pragma once
+#include <BFloat16.hpp>
#include <Half.hpp>
#include <aclCommon/ArmComputeTensorHandle.hpp>
@@ -176,6 +177,10 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<uint8_t*>(memory));
break;
+ case arm_compute::DataType::BFLOAT16:
+ armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
+ static_cast<armnn::BFloat16*>(memory));
+ break;
case arm_compute::DataType::F16:
armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
static_cast<armnn::Half*>(memory));
@@ -210,6 +215,10 @@ private:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
this->GetTensor());
break;
+ case arm_compute::DataType::BFLOAT16:
+ armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
+ this->GetTensor());
+ break;
case arm_compute::DataType::F16:
armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
this->GetTensor());