diff options
Diffstat (limited to 'src/backends/neon/NeonTensorHandle.hpp')
-rw-r--r-- | src/backends/neon/NeonTensorHandle.hpp | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/src/backends/neon/NeonTensorHandle.hpp b/src/backends/neon/NeonTensorHandle.hpp index ae8aa5d8c7..dd4c2572f9 100644 --- a/src/backends/neon/NeonTensorHandle.hpp +++ b/src/backends/neon/NeonTensorHandle.hpp @@ -29,7 +29,8 @@ public: NeonTensorHandle(const TensorInfo& tensorInfo) : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), m_Imported(false), - m_IsImportEnabled(false) + m_IsImportEnabled(false), + m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); } @@ -39,7 +40,9 @@ public: MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc)) : m_ImportFlags(importFlags), m_Imported(false), - m_IsImportEnabled(false) + m_IsImportEnabled(false), + m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) + { armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); @@ -117,9 +120,7 @@ public: { if (source == MemorySource::Malloc && m_IsImportEnabled) { - // Checks the 16 byte memory alignment - constexpr uintptr_t alignment = sizeof(size_t); - if (reinterpret_cast<uintptr_t>(memory) % alignment) + if (reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment) { throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory"); } @@ -263,6 +264,7 @@ private: MemorySourceFlags m_ImportFlags; bool m_Imported; bool m_IsImportEnabled; + const uintptr_t m_TypeAlignment; }; class NeonSubTensorHandle : public IAclTensorHandle |