aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/Workload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/Workload.hpp')
-rw-r--r--src/backends/backendsCommon/Workload.hpp219
1 files changed, 4 insertions, 215 deletions
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 87869c9841..00b6bfe4a7 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -1,219 +1,8 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#pragma once
-#include "WorkloadData.hpp"
-#include "WorkloadInfo.hpp"
-#include "WorkingMemDescriptor.hpp"
-
-#include <armnn/backends/IWorkload.hpp>
-#include <Profiling.hpp>
-#include <ProfilingService.hpp>
-
-#include <algorithm>
-
-namespace armnn
-{
-
-// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
-// in the various workload factories.
-// There should never be an instantiation of a NullWorkload.
-class NullWorkload : public IWorkload
-{
- NullWorkload()=delete;
-};
-
-template <typename QueueDescriptor>
-class BaseWorkload : public IWorkload
-{
-public:
-
- BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
- : m_Data(descriptor),
- m_Guid(profiling::ProfilingService::GetNextGuid())
- {
- m_Data.Validate(info);
- }
-
- void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
- {
- ARMNN_LOG(info) << "Using default async workload execution, this will network affect performance";
- std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
-
- m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
- m_Data.m_Outputs = workingMemDescriptor.m_Outputs;
-
- Execute();
- };
-
- void PostAllocationConfigure() override {}
-
- const QueueDescriptor& GetData() const { return m_Data; }
-
- profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
-
-protected:
- QueueDescriptor m_Data;
- const profiling::ProfilingGuid m_Guid;
-
-private:
- std::mutex m_AsyncWorkloadMutex;
-};
-
-// TypedWorkload used
-template <typename QueueDescriptor, armnn::DataType... DataTypes>
-class TypedWorkload : public BaseWorkload<QueueDescriptor>
-{
-public:
-
- TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<QueueDescriptor>(descriptor, info)
- {
- std::vector<armnn::DataType> dataTypes = {DataTypes...};
- armnn::DataType expectedInputType;
-
- if (!info.m_InputTensorInfos.empty())
- {
- expectedInputType = info.m_InputTensorInfos.front().GetDataType();
-
- if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
- {
- ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
- }
- ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
- info.m_InputTensorInfos.end(),
- [&](auto it){
- return it.GetDataType() == expectedInputType;
- }),
- "Trying to create workload with incorrect type");
- }
- armnn::DataType expectedOutputType;
-
- if (!info.m_OutputTensorInfos.empty())
- {
- expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
-
- if (!info.m_InputTensorInfos.empty())
- {
- if (expectedOutputType != expectedInputType)
- {
- ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
- }
- }
- else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
- {
- ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
- }
- ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
- info.m_OutputTensorInfos.end(),
- [&](auto it){
- return it.GetDataType() == expectedOutputType;
- }),
- "Trying to create workload with incorrect type");
- }
- }
-};
-
-template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
-class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
-{
-public:
-
- MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<QueueDescriptor>(descriptor, info)
- {
- ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
- info.m_InputTensorInfos.end(),
- [&](auto it){
- return it.GetDataType() == InputDataType;
- }),
- "Trying to create workload with incorrect type");
-
- ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
- info.m_OutputTensorInfos.end(),
- [&](auto it){
- return it.GetDataType() == OutputDataType;
- }),
- "Trying to create workload with incorrect type");
- }
-};
-
-// FirstInputTypedWorkload used to check type of the first input
-template <typename QueueDescriptor, armnn::DataType DataType>
-class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
-{
-public:
-
- FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<QueueDescriptor>(descriptor, info)
- {
- if (!info.m_InputTensorInfos.empty())
- {
- ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
- "Trying to create workload with incorrect type");
- }
-
- ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
- info.m_OutputTensorInfos.end(),
- [&](auto it){
- return it.GetDataType() == DataType;
- }),
- "Trying to create workload with incorrect type");
- }
-};
-
-template <typename QueueDescriptor>
-using FloatWorkload = TypedWorkload<QueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Float32>;
-
-template <typename QueueDescriptor>
-using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
-
-template <typename QueueDescriptor>
-using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
-
-template <typename QueueDescriptor>
-using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
-
-template <typename QueueDescriptor>
-using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
-
-template <typename QueueDescriptor>
-using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::Boolean>;
-
-template <typename QueueDescriptor>
-using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::QAsymmU8,
- armnn::DataType::Boolean>;
-
-template <typename QueueDescriptor>
-using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::BFloat16,
- armnn::DataType::Float32>;
-
-template <typename QueueDescriptor>
-using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::BFloat16>;
-
-template <typename QueueDescriptor>
-using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Float32>;
-
-template <typename QueueDescriptor>
-using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::Float16>;
-
-template <typename QueueDescriptor>
-using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
- armnn::DataType::QAsymmU8,
- armnn::DataType::Float32>;
-
-} //namespace armnn
+#include <armnn/backends/Workload.hpp>
+#pragma message("src/backends/backendsCommon/Workload.hpp has been deprecated, it is due for removal in"\
+ " 22.08 release. Please use public interface include/armnn/backends/Workload.hpp")