aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/CpuTensorHandle.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/CpuTensorHandle.hpp')
-rw-r--r--src/backends/backendsCommon/CpuTensorHandle.hpp5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/CpuTensorHandle.hpp b/src/backends/backendsCommon/CpuTensorHandle.hpp
index b88a0d385b..dd6413f2e7 100644
--- a/src/backends/backendsCommon/CpuTensorHandle.hpp
+++ b/src/backends/backendsCommon/CpuTensorHandle.hpp
@@ -5,6 +5,7 @@
#pragma once
#include "CpuTensorHandleFwd.hpp"
+#include "CompatibleTypes.hpp"
#include <armnn/TypesUtils.hpp>
@@ -22,7 +23,7 @@ public:
template <typename T>
const T* GetConstTensor() const
{
- BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
+ BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
return reinterpret_cast<const T*>(m_Memory);
}
@@ -82,7 +83,7 @@ public:
template <typename T>
T* GetTensor() const
{
- BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>());
+ BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType()));
return reinterpret_cast<T*>(m_MutableMemory);
}