diff options
Diffstat (limited to 'src/backends/backendsCommon/Workload.hpp')
-rw-r--r-- | src/backends/backendsCommon/Workload.hpp | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp index 309d53f48e..65392194a2 100644 --- a/src/backends/backendsCommon/Workload.hpp +++ b/src/backends/backendsCommon/Workload.hpp @@ -116,6 +116,7 @@ public: return it.GetDataType() == InputDataType; }), "Trying to create workload with incorrect type"); + BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), [&](auto it){ @@ -125,6 +126,30 @@ public: } }; +// FirstInputTypedWorkload used to check type of the first input +template <typename QueueDescriptor, armnn::DataType DataType> +class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor> +{ +public: + + FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload<QueueDescriptor>(descriptor, info) + { + if (!info.m_InputTensorInfos.empty()) + { + BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType, + "Trying to create workload with incorrect type"); + } + + BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), + info.m_OutputTensorInfos.end(), + [&](auto it){ + return it.GetDataType() == DataType; + }), + "Trying to create workload with incorrect type"); + } +}; + template <typename QueueDescriptor> using FloatWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Float16, |