aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/Workload.hpp
diff options
context:
space:
mode:
authornarpra01 <narumol.prangnawarat@arm.com>2019-01-18 16:53:53 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2019-01-22 17:46:51 +0000
commit4951d84b1174a4bb0a5d9c900740f64201f765bf (patch)
treefe713740ac0acbaa8d74bbc9cdb450d08ef9f575 /src/backends/backendsCommon/Workload.hpp
parent0edd46737065d3e5c09aa959807e81f9836ee709 (diff)
downloadarmnn-4951d84b1174a4bb0a5d9c900740f64201f765bf.tar.gz
IVGCVSW-2510 Ref workload implementation for Gather operator
* add implemenentation for GatherQueueDescriptor validate function * add FirstInputTypedWorkload to allow type check on the first input tensor only * add ref workload implemenentation for float and uint8 * add Gather layer support in Ref * unit tests Change-Id: I4578a3211f11d24aa29d15bcf7f45b0445bcd1ee
Diffstat (limited to 'src/backends/backendsCommon/Workload.hpp')
-rw-r--r--src/backends/backendsCommon/Workload.hpp25
1 files changed, 25 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 309d53f48e..65392194a2 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -116,6 +116,7 @@ public:
return it.GetDataType() == InputDataType;
}),
"Trying to create workload with incorrect type");
+
BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
info.m_OutputTensorInfos.end(),
[&](auto it){
@@ -125,6 +126,30 @@ public:
}
};
+// FirstInputTypedWorkload used to check type of the first input
+template <typename QueueDescriptor, armnn::DataType DataType>
+class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+ FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {
+ if (!info.m_InputTensorInfos.empty())
+ {
+ BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
+ "Trying to create workload with incorrect type");
+ }
+
+ BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+ info.m_OutputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == DataType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+};
+
template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
armnn::DataType::Float16,