//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"

#include <armnn/Logging.hpp>
#include <armnn/utility/Assert.hpp>
#include <ProfilingService.hpp>

#include <algorithm>
#include <mutex>
#include <vector>

namespace armnn
{

// NullWorkload is used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};

template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:
    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(profiling::ProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

    // Default asynchronous execution: serialise access to m_Data, bind the working
    // memory to the descriptor, then fall back to the synchronous Execute().
    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);

        m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor.m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    profiling::ProfilingGuid GetGuid() const final { return m_Guid; }

    // Replace input tensor handle with the given TensorHandle
    void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        m_Data.m_Inputs[slot] = tensorHandle;
    }

    // Replace output tensor handle with the given TensorHandle
    void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        m_Data.m_Outputs[slot] = tensorHandle;
    }

protected:
    QueueDescriptor m_Data;
    const profiling::ProfilingGuid m_Guid;

private:
    std::mutex m_AsyncWorkloadMutex;
};

// TypedWorkload is used to check that all input and output tensors have one of the
// permitted DataTypes, and that inputs and outputs agree with each other.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it)
                                         {
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }

        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it)
                                         {
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};

// MultiTypedWorkload is used when the input and output tensors may have different
// data types, e.g. for conversion workloads.
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it)
                                     {
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it)
                                     {
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

// FirstInputTypedWorkload is used to check the type of the first input only;
// all outputs must match the given DataType.
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it)
                                     {
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

} //namespace armnn
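
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, kept in comments so it is not compiled):
// a backend workload typically derives from one of the aliases above, which
// validates tensor data types at construction time, and then implements
// Execute(). The class name below is hypothetical; AdditionQueueDescriptor
// comes from WorkloadData.hpp.
//
//     class SampleAdditionWorkload : public Float32Workload<AdditionQueueDescriptor>
//     {
//     public:
//         SampleAdditionWorkload(const AdditionQueueDescriptor& descriptor,
//                                const WorkloadInfo& info)
//             : Float32Workload<AdditionQueueDescriptor>(descriptor, info)
//         {}
//
//         void Execute() const override
//         {
//             // Read from m_Data.m_Inputs, write to m_Data.m_Outputs.
//         }
//     };
//
// Constructing such a workload with non-Float32 tensors in the WorkloadInfo
// would trip the asserts in the TypedWorkload constructor.
// ---------------------------------------------------------------------------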