diff options
Diffstat (limited to 'src/backends/backendsCommon/CpuTensorHandle.hpp')
-rw-r--r-- | src/backends/backendsCommon/CpuTensorHandle.hpp | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/CpuTensorHandle.hpp b/src/backends/backendsCommon/CpuTensorHandle.hpp index b88a0d385b..dd6413f2e7 100644 --- a/src/backends/backendsCommon/CpuTensorHandle.hpp +++ b/src/backends/backendsCommon/CpuTensorHandle.hpp @@ -5,6 +5,7 @@ #pragma once #include "CpuTensorHandleFwd.hpp" +#include "CompatibleTypes.hpp" #include <armnn/TypesUtils.hpp> @@ -22,7 +23,7 @@ public: template <typename T> const T* GetConstTensor() const { - BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>()); + BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType())); return reinterpret_cast<const T*>(m_Memory); } @@ -82,7 +83,7 @@ public: template <typename T> T* GetTensor() const { - BOOST_ASSERT(GetTensorInfo().GetDataType() == GetDataType<T>()); + BOOST_ASSERT(CompatibleTypes<T>(GetTensorInfo().GetDataType())); return reinterpret_cast<T*>(m_MutableMemory); } |