aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/Workload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/Workload.hpp')
-rw-r--r--src/backends/backendsCommon/Workload.hpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 309d53f48e..65392194a2 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -116,6 +116,7 @@ public:
return it.GetDataType() == InputDataType;
}),
"Trying to create workload with incorrect type");
+
BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
info.m_OutputTensorInfos.end(),
[&](auto it){
@@ -125,6 +126,30 @@ public:
}
};
+// FirstInputTypedWorkload used to check type of the first input
+template <typename QueueDescriptor, armnn::DataType DataType>
+class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+ FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {
+ if (!info.m_InputTensorInfos.empty())
+ {
+ BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
+ "Trying to create workload with incorrect type");
+ }
+
+ BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+ info.m_OutputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == DataType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+};
+
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
armnn::DataType::Float16,