// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "WorkloadData.hpp" #include "WorkloadInfo.hpp" #include #include #include #include namespace armnn { /// Workload interface to enqueue a layer computation. class IWorkload { public: virtual ~IWorkload() {} virtual void PostAllocationConfigure() = 0; virtual void Execute() const = 0; virtual profiling::ProfilingGuid GetGuid() const = 0; virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {} }; // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template // in the various workload factories. // There should never be an instantiation of a NullWorkload. class NullWorkload : public IWorkload { NullWorkload()=delete; }; template class BaseWorkload : public IWorkload { public: BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) : m_Data(descriptor), m_Guid(profiling::ProfilingService::Instance().NextGuid()) { m_Data.Validate(info); } void PostAllocationConfigure() override {} const QueueDescriptor& GetData() const { return m_Data; } profiling::ProfilingGuid GetGuid() const final { return m_Guid; } protected: const QueueDescriptor m_Data; const profiling::ProfilingGuid m_Guid; }; // TypedWorkload used template class TypedWorkload : public BaseWorkload { public: TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) { std::vector dataTypes = {DataTypes...}; armnn::DataType expectedInputType; if (!info.m_InputTensorInfos.empty()) { expectedInputType = info.m_InputTensorInfos.front().GetDataType(); if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end()) { BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type"); } BOOST_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()), info.m_InputTensorInfos.end(), [&](auto it){ return it.GetDataType() == expectedInputType; }), "Trying to create workload with incorrect type"); } armnn::DataType expectedOutputType; if (!info.m_OutputTensorInfos.empty()) { expectedOutputType = info.m_OutputTensorInfos.front().GetDataType(); if (!info.m_InputTensorInfos.empty()) { if (expectedOutputType != expectedInputType) { BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type"); } } else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end()) { BOOST_ASSERT_MSG(false, "Trying to create workload with incorrect type"); } BOOST_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()), info.m_OutputTensorInfos.end(), [&](auto it){ return it.GetDataType() == expectedOutputType; }), "Trying to create workload with incorrect type"); } } }; template class MultiTypedWorkload : public BaseWorkload { public: MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) { BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(), info.m_InputTensorInfos.end(), [&](auto it){ return it.GetDataType() == InputDataType; }), "Trying to create workload with incorrect type"); BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), [&](auto it){ return it.GetDataType() == OutputDataType; }), "Trying to create workload with incorrect type"); } }; // FirstInputTypedWorkload used to check type of the first input template class FirstInputTypedWorkload : public BaseWorkload { public: FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload(descriptor, info) { if (!info.m_InputTensorInfos.empty()) { BOOST_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType, "Trying to create workload with incorrect type"); } BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), info.m_OutputTensorInfos.end(), [&](auto it){ return it.GetDataType() == DataType; }), "Trying to create workload with incorrect type"); } }; template using FloatWorkload = TypedWorkload; template using Float32Workload = TypedWorkload; template using Uint8Workload = TypedWorkload; template using Int32Workload = TypedWorkload; template using BooleanWorkload = TypedWorkload; template using BaseFloat32ComparisonWorkload = MultiTypedWorkload; template using BaseUint8ComparisonWorkload = MultiTypedWorkload; template using Float16ToFloat32Workload = MultiTypedWorkload; template using Float32ToFloat16Workload = MultiTypedWorkload; template using Uint8ToFloat32Workload = MultiTypedWorkload; } //namespace armnn