From 0c47974f1800e8770904aecaef15d6f105758c4e Mon Sep 17 00:00:00 2001
From: Colm Donelan
Date: Fri, 10 Dec 2021 12:43:54 +0000
Subject: IVGCVSW-6626 Promote backend headers in backendsCommon to armnn/backends

Move the following header files from backendsCommon to armnn/backends:
* MemCopyWorkload.hpp
* TensorHandle.hpp
* Workload.hpp
* WorkloadData.hpp
* WorkloadFactory.hpp

Replace them with forwarding headers that carry a pragma deprecation message,
and resolve the resulting deprecation messages in the Arm NN code.

Signed-off-by: Colm Donelan
Change-Id: I47f116b30f86e478c9057795bc518c391a8ae514
---
 include/armnn/backends/MemCopyWorkload.hpp   |  27 +
 include/armnn/backends/TensorHandle.hpp      | 267 ++++
 include/armnn/backends/Workload.hpp          | 219 ++++
 include/armnn/backends/WorkloadData.hpp      | 769 +++++++++++++++++++++++++++
 include/armnn/backends/WorkloadFactory.hpp   | 289 ++++
 include/armnnTestUtils/WorkloadTestUtils.hpp | 113 ++++
 6 files changed, 1684 insertions(+)
 create mode 100644 include/armnn/backends/MemCopyWorkload.hpp
 create mode 100644 include/armnn/backends/TensorHandle.hpp
 create mode 100644 include/armnn/backends/Workload.hpp
 create mode 100644 include/armnn/backends/WorkloadData.hpp
 create mode 100644 include/armnn/backends/WorkloadFactory.hpp
 create mode 100644 include/armnnTestUtils/WorkloadTestUtils.hpp

diff --git a/include/armnn/backends/MemCopyWorkload.hpp b/include/armnn/backends/MemCopyWorkload.hpp
new file mode 100644
index 0000000000..da23f52be6
--- /dev/null
+++ b/include/armnn/backends/MemCopyWorkload.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "TensorHandle.hpp"
+#include "Workload.hpp"
+
+#include
+
+namespace armnn
+{
+
+class CopyMemGenericWorkload : public BaseWorkload<MemCopyQueueDescriptor>
+{
+public:
+    CopyMemGenericWorkload(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info);
+    void Execute() const override;
+    void ExecuteAsync(WorkingMemDescriptor& descriptor) override;
+
+private:
+    using TensorHandlePair = std::pair;
+    std::vector<TensorHandlePair> m_TensorHandlePairs;
+};
+
+} //namespace armnn
diff --git a/include/armnn/backends/TensorHandle.hpp b/include/armnn/backends/TensorHandle.hpp
new file mode 100644
index 0000000000..2e6c8485d1
--- /dev/null
+++ b/include/armnn/backends/TensorHandle.hpp
@@ -0,0 +1,267 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ITensorHandle.hpp"
+
+#include
+#include
+#include
+
+#include
+
+namespace armnn
+{
+
+// Get a TensorShape representing the strides (in bytes) for each dimension
+// of a tensor, assuming fully packed data with no padding.
+TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
+
+// Abstract tensor handle wrapping a readable region of memory, interpreting it as tensor data.
+class ConstTensorHandle : public ITensorHandle +{ +public: + template + const T* GetConstTensor() const + { + if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) + { + return reinterpret_cast(m_Memory); + } + else + { + throw armnn::Exception("Attempting to get not compatible type tensor!"); + } + } + + const TensorInfo& GetTensorInfo() const + { + return m_TensorInfo; + } + + virtual void Manage() override {} + + virtual ITensorHandle* GetParent() const override { return nullptr; } + + virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; } + virtual void Unmap() const override {} + + TensorShape GetStrides() const override + { + return GetUnpaddedTensorStrides(m_TensorInfo); + } + TensorShape GetShape() const override { return m_TensorInfo.GetShape(); } + +protected: + ConstTensorHandle(const TensorInfo& tensorInfo); + + void SetConstMemory(const void* mem) { m_Memory = mem; } + +private: + // Only used for testing + void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } + void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); } + + ConstTensorHandle(const ConstTensorHandle& other) = delete; + ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete; + + TensorInfo m_TensorInfo; + const void* m_Memory; +}; + +template<> +const void* ConstTensorHandle::GetConstTensor() const; + +// Abstract specialization of ConstTensorHandle that allows write access to the same data. +class TensorHandle : public ConstTensorHandle +{ +public: + template + T* GetTensor() const + { + if (armnnUtils::CompatibleTypes(GetTensorInfo().GetDataType())) + { + return reinterpret_cast(m_MutableMemory); + } + else + { + throw armnn::Exception("Attempting to get not compatible type tensor!"); + } + } + +protected: + TensorHandle(const TensorInfo& tensorInfo); + + void SetMemory(void* mem) + { + m_MutableMemory = mem; + SetConstMemory(m_MutableMemory); + } + +private: + + TensorHandle(const TensorHandle& other) = delete; + TensorHandle& operator=(const TensorHandle& other) = delete; + void* m_MutableMemory; +}; + +template <> +void* TensorHandle::GetTensor() const; + +// A TensorHandle that owns the wrapped memory region. +class ScopedTensorHandle : public TensorHandle +{ +public: + explicit ScopedTensorHandle(const TensorInfo& tensorInfo); + + // Copies contents from Tensor. + explicit ScopedTensorHandle(const ConstTensor& tensor); + + // Copies contents from ConstTensorHandle + explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle); + + ScopedTensorHandle(const ScopedTensorHandle& other); + ScopedTensorHandle& operator=(const ScopedTensorHandle& other); + ~ScopedTensorHandle(); + + virtual void Allocate() override; + +private: + // Only used for testing + void CopyOutTo(void* memory) const override; + void CopyInFrom(const void* memory) override; + + void CopyFrom(const ScopedTensorHandle& other); + void CopyFrom(const void* srcMemory, unsigned int numBytes); +}; + +// A TensorHandle that wraps an already allocated memory region. +// +// Clients must make sure the passed in memory region stays alive for the lifetime of +// the PassthroughTensorHandle instance. +// +// Note there is no polymorphism to/from ConstPassthroughTensorHandle. 
+class PassthroughTensorHandle : public TensorHandle +{ +public: + PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem) + : TensorHandle(tensorInfo) + { + SetMemory(mem); + } + + virtual void Allocate() override; +}; + +// A ConstTensorHandle that wraps an already allocated memory region. +// +// This allows users to pass in const memory to a network. +// Clients must make sure the passed in memory region stays alive for the lifetime of +// the PassthroughTensorHandle instance. +// +// Note there is no polymorphism to/from PassthroughTensorHandle. +class ConstPassthroughTensorHandle : public ConstTensorHandle +{ +public: + ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem) + : ConstTensorHandle(tensorInfo) + { + SetConstMemory(mem); + } + + virtual void Allocate() override; +}; + + +// Template specializations. + +template <> +const void* ConstTensorHandle::GetConstTensor() const; + +template <> +void* TensorHandle::GetTensor() const; + +class ManagedConstTensorHandle +{ + +public: + explicit ManagedConstTensorHandle(std::shared_ptr ptr) + : m_Mapped(false) + , m_TensorHandle(std::move(ptr)) {}; + + /// RAII Managed resource Unmaps MemoryArea once out of scope + const void* Map(bool blocking = true) + { + if (m_TensorHandle) + { + auto pRet = m_TensorHandle->Map(blocking); + m_Mapped = true; + return pRet; + } + else + { + throw armnn::Exception("Attempting to Map null TensorHandle"); + } + + } + + // Delete copy constructor as it's unnecessary + ManagedConstTensorHandle(const ConstTensorHandle& other) = delete; + + // Delete copy assignment as it's unnecessary + ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete; + + // Delete move assignment as it's unnecessary + ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete; + + ~ManagedConstTensorHandle() + { + // Bias tensor handles need to be initialized empty before entering scope of if statement checking if enabled + if (m_TensorHandle) + { + Unmap(); + } + } + + void Unmap() + { + // Only unmap if mapped and TensorHandle exists. 
+ if (m_Mapped && m_TensorHandle) + { + m_TensorHandle->Unmap(); + m_Mapped = false; + } + } + + const TensorInfo& GetTensorInfo() const + { + return m_TensorHandle->GetTensorInfo(); + } + + bool IsMapped() const + { + return m_Mapped; + } + +private: + bool m_Mapped; + std::shared_ptr m_TensorHandle; +}; + +using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstCpuTensorHandle is deprecated, " + "use ConstTensorHandle instead", "22.05") = ConstTensorHandle; +using CpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("CpuTensorHandle is deprecated, " + "use TensorHandle instead", "22.05") = TensorHandle; +using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ScopedCpuTensorHandle is deprecated, " + "use ScopedTensorHandle instead", "22.05") = ScopedTensorHandle; +using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("PassthroughCpuTensorHandle is deprecated, use " + "PassthroughTensorHandle instead", + "22.05") = PassthroughTensorHandle; +using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstPassthroughCpuTensorHandle is " + "deprecated, use ConstPassthroughTensorHandle " + "instead", "22.05") = ConstPassthroughTensorHandle; + +} // namespace armnn diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp new file mode 100644 index 0000000000..7c1bda50bc --- /dev/null +++ b/include/armnn/backends/Workload.hpp @@ -0,0 +1,219 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "IWorkload.hpp" +#include "WorkloadData.hpp" +#include "WorkloadInfo.hpp" +#include "WorkingMemDescriptor.hpp" + +#include +#include + +#include + +namespace armnn +{ + +// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template +// in the various workload factories. +// There should never be an instantiation of a NullWorkload. 
+class NullWorkload : public IWorkload
+{
+    NullWorkload()=delete;
+};
+
+template <typename QueueDescriptor>
+class BaseWorkload : public IWorkload
+{
+public:
+
+    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : m_Data(descriptor),
+          m_Guid(profiling::ProfilingService::GetNextGuid())
+    {
+        m_Data.Validate(info);
+    }
+
+    void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
+    {
+        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
+        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
+
+        m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
+        m_Data.m_Outputs = workingMemDescriptor.m_Outputs;
+
+        Execute();
+    };
+
+    void PostAllocationConfigure() override {}
+
+    const QueueDescriptor& GetData() const { return m_Data; }
+
+    profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
+
+protected:
+    QueueDescriptor m_Data;
+    const profiling::ProfilingGuid m_Guid;
+
+private:
+    std::mutex m_AsyncWorkloadMutex;
+};
+
+// TypedWorkload used to check that a workload's input and output data types match one of the given DataTypes.
+template <typename QueueDescriptor, armnn::DataType... DataTypes>
+class TypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : BaseWorkload<QueueDescriptor>(descriptor, info)
+    {
+        std::vector<armnn::DataType> dataTypes = {DataTypes...};
+        armnn::DataType expectedInputType;
+
+        if (!info.m_InputTensorInfos.empty())
+        {
+            expectedInputType = info.m_InputTensorInfos.front().GetDataType();
+
+            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
+            {
+                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+            }
+            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
+                                         info.m_InputTensorInfos.end(),
+                                         [&](auto it){
+                                             return it.GetDataType() == expectedInputType;
+                                         }),
+                             "Trying to create workload with incorrect type");
+        }
+        armnn::DataType expectedOutputType;
+
+        if (!info.m_OutputTensorInfos.empty())
+        {
+            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
+
+            if (!info.m_InputTensorInfos.empty())
+            {
+                if (expectedOutputType != expectedInputType)
+                {
+                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+                }
+            }
+            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
+            {
+                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+            }
+            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
+                                         info.m_OutputTensorInfos.end(),
+                                         [&](auto it){
+                                             return it.GetDataType() == expectedOutputType;
+                                         }),
+                             "Trying to create workload with incorrect type");
+        }
+    }
+};
+
+template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
+class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : BaseWorkload<QueueDescriptor>(descriptor, info)
+    {
+        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
+                                     info.m_InputTensorInfos.end(),
+                                     [&](auto it){
+                                         return it.GetDataType() == InputDataType;
+                                     }),
+                         "Trying to create workload with incorrect type");
+
+        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+                                     info.m_OutputTensorInfos.end(),
+                                     [&](auto it){
+                                         return it.GetDataType() == OutputDataType;
+                                     }),
+                         "Trying to create workload with incorrect type");
+    }
+};
+
+// FirstInputTypedWorkload used to check the type of the first input only.
+template <typename QueueDescriptor, armnn::DataType DataType>
+class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+        : BaseWorkload<QueueDescriptor>(descriptor, info)
+    {
+        if (!info.m_InputTensorInfos.empty())
+        {
+
ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType, + "Trying to create workload with incorrect type"); + } + + ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), + info.m_OutputTensorInfos.end(), + [&](auto it){ + return it.GetDataType() == DataType; + }), + "Trying to create workload with incorrect type"); + } +}; + +template +using FloatWorkload = TypedWorkload; + +template +using Float32Workload = TypedWorkload; + +template +using Uint8Workload = TypedWorkload; + +template +using Int32Workload = TypedWorkload; + +template +using BooleanWorkload = TypedWorkload; + +template +using BaseFloat32ComparisonWorkload = MultiTypedWorkload; + +template +using BaseUint8ComparisonWorkload = MultiTypedWorkload; + +template +using BFloat16ToFloat32Workload = MultiTypedWorkload; + +template +using Float32ToBFloat16Workload = MultiTypedWorkload; + +template +using Float16ToFloat32Workload = MultiTypedWorkload; + +template +using Float32ToFloat16Workload = MultiTypedWorkload; + +template +using Uint8ToFloat32Workload = MultiTypedWorkload; + +} //namespace armnn diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp new file mode 100644 index 0000000000..7406547216 --- /dev/null +++ b/include/armnn/backends/WorkloadData.hpp @@ -0,0 +1,769 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "TensorHandle.hpp" + +#include +#include +#include +#include +#include +#include + +namespace armnn +{ + +//A helper function that returns the bias data type required for given input data type. +DataType GetBiasDataType(DataType inputDataType); + +struct WorkloadInfo; + +struct QueueDescriptor +{ + std::vector m_Inputs; + std::vector m_Outputs; + void* m_AdditionalInfoObject; + + void ValidateInputsOutputs(const std::string& descName, + unsigned int numExpectedIn, + unsigned int numExpectedOut) const; + + template + const T* GetAdditionalInformation() const + { + return static_cast(m_AdditionalInfoObject); + } + +protected: + ~QueueDescriptor() = default; + QueueDescriptor() + : m_AdditionalInfoObject(nullptr) + {} + QueueDescriptor(QueueDescriptor const&) = default; + QueueDescriptor& operator=(QueueDescriptor const&) = default; +}; + +// Base class for queue descriptors which contain parameters. +template +struct QueueDescriptorWithParameters : public QueueDescriptor +{ + LayerDescriptor m_Parameters; + +protected: + ~QueueDescriptorWithParameters() = default; + QueueDescriptorWithParameters() = default; + QueueDescriptorWithParameters(QueueDescriptorWithParameters const&) = default; + QueueDescriptorWithParameters& operator=(QueueDescriptorWithParameters const&) = default; +}; + +struct MapQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct UnmapQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct MemCopyQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +using InputQueueDescriptor = MemCopyQueueDescriptor; +using OutputQueueDescriptor = MemCopyQueueDescriptor; + +struct MemImportQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct MemSyncQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Softmax layer workload data. 
+struct SoftmaxQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Splitter layer workload data. +struct SplitterQueueDescriptor : QueueDescriptorWithParameters +{ + struct ViewOrigin + { + ViewOrigin() {} + ViewOrigin(std::vector const& origin) : m_Origin(origin) {} + + //View origin (size of the vector is the same as number of dimensions of the view). + std::vector m_Origin; + }; + + //View defines a tensor that will be carved from the input tensor. + //View origins are stored here, the extents are defined by sizes of the output tensors. + std::vector m_ViewOrigins; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Concat layer workload data. +struct ConcatQueueDescriptor : QueueDescriptorWithParameters +{ + struct ViewOrigin + { + ViewOrigin() {} + ViewOrigin(const std::vector& origin) : m_Origin(origin) {} + + //View origin (size of the vector is the same as number of dimensions of the view). + std::vector m_Origin; + }; + + //View defines a sub-area of the output tensor that will be filled with the corresponding input tensor. + //View origins are stored here, the extents are defined by sizes of the input tensors. + std::vector m_ViewOrigins; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Deprecated. Use ConcatQueueDescriptor instead +using MergerQueueDescriptor = ConcatQueueDescriptor; + +// Stack layer workload data. +struct StackQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Activation layer workload data. +struct ActivationQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ArgMinMaxQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct CastQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Fill layer workload data. +struct FillQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Fully connected layer workload data. +struct FullyConnectedQueueDescriptor : QueueDescriptorWithParameters +{ + FullyConnectedQueueDescriptor() + : m_Weight(nullptr) + , m_Bias(nullptr) + { + } + + const ConstTensorHandle* m_Weight; + const ConstTensorHandle* m_Bias; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Permute layer workload data. +struct PermuteQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Pooling 2D layer workload data. +struct Pooling2dQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Pooling 3D layer workload data. +struct Pooling3dQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + + +// Convolution 2D layer workload data. +struct Convolution2dQueueDescriptor : QueueDescriptorWithParameters +{ + Convolution2dQueueDescriptor() + : m_Weight(nullptr) + , m_Bias(nullptr) + { + } + + const ConstTensorHandle* m_Weight; + const ConstTensorHandle* m_Bias; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Convolution 3D layer workload data. +struct Convolution3dQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +/// Depthwise Convolution 2D layer workload data. 
+///
+/// @note
+/// The weights are in the format [1, H, W, I*M], where I is the input channel size, M the depthwise
+/// multiplier and H, W are the height and width of the filter kernel. If per channel quantization is applied
+/// the weights will be quantized along the last dimension/axis (I*M) which corresponds to the output channel size.
+/// If per channel quantization is applied the weights tensor will have I*M scales, one for each dimension
+/// of the quantization axis. You have to be aware of this when reshaping the weights tensor.
+/// Splitting the I*M axis, e.g. [1, H, W, I*M] --> [H, W, I, M], won't work without taking care of the
+/// corresponding quantization scales.
+/// If there is no per channel quantization applied, reshaping the weights tensor won't cause any issues. There are
+/// preconfigured permutation functions available @link WorkloadUtils.hpp here.
+///
+struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<DepthwiseConvolution2dDescriptor>
+{
+    DepthwiseConvolution2dQueueDescriptor()
+        : m_Weight(nullptr)
+        , m_Bias(nullptr)
+    {
+    }
+
+    const ConstTensorHandle* m_Weight;
+    const ConstTensorHandle* m_Bias;
+
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct DetectionPostProcessQueueDescriptor : QueueDescriptorWithParameters<DetectionPostProcessDescriptor>
+{
+    DetectionPostProcessQueueDescriptor()
+        : m_Anchors(nullptr)
+    {
+    }
+
+    const ConstTensorHandle* m_Anchors;
+
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Normalization layer workload data.
+struct NormalizationQueueDescriptor : QueueDescriptorWithParameters<NormalizationDescriptor>
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Add layer workload data.
+struct AdditionQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Multiplication layer workload data.
+struct MultiplicationQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Division layer workload data.
+struct DivisionQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Subtraction layer workload data.
+struct SubtractionQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Maximum layer workload data.
+struct MaximumQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Mean layer workload data.
+struct MeanQueueDescriptor : QueueDescriptorWithParameters<MeanDescriptor>
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Pad layer workload data.
+struct PadQueueDescriptor : QueueDescriptorWithParameters<PadDescriptor>
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct QuantizeQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Deprecated, use ComparisonQueueDescriptor instead.
+struct EqualQueueDescriptor : QueueDescriptor
+{
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Batch norm layer workload data.
+struct BatchNormalizationQueueDescriptor : QueueDescriptorWithParameters +{ + BatchNormalizationQueueDescriptor() + : m_Mean(nullptr) + , m_Variance(nullptr) + , m_Beta(nullptr) + , m_Gamma(nullptr) + { + } + + const ConstTensorHandle* m_Mean; + const ConstTensorHandle* m_Variance; + const ConstTensorHandle* m_Beta; + const ConstTensorHandle* m_Gamma; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct RankQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +ARMNN_NO_DEPRECATE_WARN_BEGIN +struct +ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ResizeBilinearQueueDescriptor is deprecated use ResizeQueueDescriptor instead", + "22.08") +ResizeBilinearQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; +ARMNN_NO_DEPRECATE_WARN_END + +struct ResizeQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters +{ + FakeQuantizationQueueDescriptor() + : m_Min(nullptr) + , m_Max(nullptr) + { + } + + const ConstTensorHandle* m_Min; + const ConstTensorHandle* m_Max; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct LogSoftmaxQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ConstantQueueDescriptor : QueueDescriptor +{ + ConstantQueueDescriptor() + : m_LayerOutput(nullptr) + { + } + + const ConstTensorHandle* m_LayerOutput; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ReshapeQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct SpaceToBatchNdQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct SpaceToDepthQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct FloorQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct LstmQueueDescriptor : QueueDescriptorWithParameters +{ + LstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstTensorHandle* m_InputToInputWeights; + const ConstTensorHandle* m_InputToForgetWeights; + const ConstTensorHandle* m_InputToCellWeights; + const ConstTensorHandle* m_InputToOutputWeights; + const ConstTensorHandle* m_RecurrentToInputWeights; + const ConstTensorHandle* 
m_RecurrentToForgetWeights; + const ConstTensorHandle* m_RecurrentToCellWeights; + const ConstTensorHandle* m_RecurrentToOutputWeights; + const ConstTensorHandle* m_CellToInputWeights; + const ConstTensorHandle* m_CellToForgetWeights; + const ConstTensorHandle* m_CellToOutputWeights; + const ConstTensorHandle* m_InputGateBias; + const ConstTensorHandle* m_ForgetGateBias; + const ConstTensorHandle* m_CellBias; + const ConstTensorHandle* m_OutputGateBias; + const ConstTensorHandle* m_ProjectionWeights; + const ConstTensorHandle* m_ProjectionBias; + const ConstTensorHandle* m_InputLayerNormWeights; + const ConstTensorHandle* m_ForgetLayerNormWeights; + const ConstTensorHandle* m_CellLayerNormWeights; + const ConstTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ConvertBf16ToFp32QueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ConvertFp32ToBf16QueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ConvertFp16ToFp32QueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ConvertFp32ToFp16QueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct BatchToSpaceNdQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct StridedSliceQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Minimum layer workload data. +struct MinimumQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +// Deprecated use ComparisonQueueDescriptor instead +struct GreaterQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct DebugQueueDescriptor : QueueDescriptor +{ + DebugQueueDescriptor() : m_Guid(0) {} + + void Validate(const WorkloadInfo& workloadInfo) const; + + LayerGuid m_Guid; + std::string m_LayerName; + unsigned int m_SlotIndex; +}; + +struct RsqrtQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct GatherQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct PreCompiledQueueDescriptor : QueueDescriptorWithParameters +{ + PreCompiledQueueDescriptor() + : m_PreCompiledObject(nullptr) + { + } + + void* m_PreCompiledObject; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct DequantizeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct MergeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct SwitchQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct PreluQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct TransposeConvolution2dQueueDescriptor : QueueDescriptorWithParameters +{ + TransposeConvolution2dQueueDescriptor() : + m_Weight(nullptr), + m_Bias(nullptr) + {} + + const ConstTensorHandle* m_Weight; + const ConstTensorHandle* m_Bias; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct TransposeQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct QLstmQueueDescriptor : 
QueueDescriptorWithParameters +{ + QLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstTensorHandle* m_InputToInputWeights; + const ConstTensorHandle* m_InputToForgetWeights; + const ConstTensorHandle* m_InputToCellWeights; + const ConstTensorHandle* m_InputToOutputWeights; + const ConstTensorHandle* m_RecurrentToInputWeights; + const ConstTensorHandle* m_RecurrentToForgetWeights; + const ConstTensorHandle* m_RecurrentToCellWeights; + const ConstTensorHandle* m_RecurrentToOutputWeights; + const ConstTensorHandle* m_CellToInputWeights; + const ConstTensorHandle* m_CellToForgetWeights; + const ConstTensorHandle* m_CellToOutputWeights; + const ConstTensorHandle* m_InputGateBias; + const ConstTensorHandle* m_ForgetGateBias; + const ConstTensorHandle* m_CellBias; + const ConstTensorHandle* m_OutputGateBias; + const ConstTensorHandle* m_ProjectionWeights; + const ConstTensorHandle* m_ProjectionBias; + const ConstTensorHandle* m_InputLayerNormWeights; + const ConstTensorHandle* m_ForgetLayerNormWeights; + const ConstTensorHandle* m_CellLayerNormWeights; + const ConstTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct QuantizedLstmQueueDescriptor : QueueDescriptor +{ + QuantizedLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + {} + + const ConstTensorHandle* m_InputToInputWeights; + const ConstTensorHandle* m_InputToForgetWeights; + const ConstTensorHandle* m_InputToCellWeights; + const ConstTensorHandle* m_InputToOutputWeights; + + const ConstTensorHandle* m_RecurrentToInputWeights; + const ConstTensorHandle* m_RecurrentToForgetWeights; + const ConstTensorHandle* m_RecurrentToCellWeights; + const ConstTensorHandle* m_RecurrentToOutputWeights; + + const ConstTensorHandle* m_InputGateBias; + const ConstTensorHandle* m_ForgetGateBias; + const ConstTensorHandle* m_CellBias; + const ConstTensorHandle* m_OutputGateBias; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct AbsQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct SliceQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct DepthToSpaceQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ComparisonQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& 
workloadInfo) const; +}; + +struct ElementwiseUnaryQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct LogicalBinaryQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ReduceQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ShapeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters +{ + UnidirectionalSequenceLstmQueueDescriptor() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + + const ConstTensorHandle* m_InputToInputWeights; + const ConstTensorHandle* m_InputToForgetWeights; + const ConstTensorHandle* m_InputToCellWeights; + const ConstTensorHandle* m_InputToOutputWeights; + const ConstTensorHandle* m_RecurrentToInputWeights; + const ConstTensorHandle* m_RecurrentToForgetWeights; + const ConstTensorHandle* m_RecurrentToCellWeights; + const ConstTensorHandle* m_RecurrentToOutputWeights; + const ConstTensorHandle* m_CellToInputWeights; + const ConstTensorHandle* m_CellToForgetWeights; + const ConstTensorHandle* m_CellToOutputWeights; + const ConstTensorHandle* m_InputGateBias; + const ConstTensorHandle* m_ForgetGateBias; + const ConstTensorHandle* m_CellBias; + const ConstTensorHandle* m_OutputGateBias; + const ConstTensorHandle* m_ProjectionWeights; + const ConstTensorHandle* m_ProjectionBias; + const ConstTensorHandle* m_InputLayerNormWeights; + const ConstTensorHandle* m_ForgetLayerNormWeights; + const ConstTensorHandle* m_CellLayerNormWeights; + const ConstTensorHandle* m_OutputLayerNormWeights; + + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct ChannelShuffleQueueDescriptor : QueueDescriptorWithParameters +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +} // namespace armnn diff --git a/include/armnn/backends/WorkloadFactory.hpp b/include/armnn/backends/WorkloadFactory.hpp new file mode 100644 index 0000000000..68ad2e3741 --- /dev/null +++ b/include/armnn/backends/WorkloadFactory.hpp @@ -0,0 +1,289 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "ITensorHandle.hpp" +#include "Workload.hpp" + +#include +#include +#include + +#include + +namespace armnn +{ + +class Layer; + +// Workload factory interface for compute backends. 
+class IWorkloadFactory +{ +public: + virtual ~IWorkloadFactory() { } + + virtual void AfterWorkloadsCreated() {}; + + virtual const BackendId& GetBackendId() const = 0; + + static bool IsLayerSupported(const BackendId& backendId, + const IConnectableLayer& layer, + Optional dataType, + std::string& outReasonIfUnsupported); + + static bool IsLayerSupported(const IConnectableLayer& layer, + Optional dataType, + std::string& outReasonIfUnsupported); + + static bool IsLayerSupported(const IConnectableLayer& layer, + Optional dataType, + std::string& outReasonIfUnsupported, + const ModelOptions& modelOptions); + + static bool IsLayerSupported(const BackendId& backendId, + const IConnectableLayer& layer, + Optional dataType, + std::string& outReasonIfUnsupported, + const ModelOptions& modelOptions); + + virtual bool SupportsSubTensors() const = 0; + + ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead") + virtual std::unique_ptr CreateSubTensorHandle(ITensorHandle& parent, + TensorShape const& subTensorShape, + unsigned int const* subTensorOrigin + ) const = 0; + + virtual std::unique_ptr CreateInput(const InputQueueDescriptor& descriptor, + const WorkloadInfo& info) const = 0; + + ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") + virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, + const bool IsMemoryManaged = true) const = 0; + + ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") + virtual std::unique_ptr CreateTensorHandle(const TensorInfo& tensorInfo, + DataLayout dataLayout, + const bool IsMemoryManaged = true) const = 0; + + virtual std::unique_ptr CreateActivation(const ActivationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateAddition(const AdditionQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateCast(const CastQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateChannelShuffle(const ChannelShuffleQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateComparison(const ComparisonQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateConcat(const ConcatQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConstant(const ConstantQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvolution2d(const Convolution2dQueueDescriptor& 
descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateConvolution3d(const Convolution3dQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDebug(const DebugQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDepthwiseConvolution2d( + const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDequantize(const DequantizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDetectionPostProcess( + const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateDivision(const DivisionQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateFill(const FillQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateFloor(const FloorQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateGather(const GatherQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateInstanceNormalization( + const InstanceNormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateLstm(const LstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMaximum(const MaximumQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMean(const MeanQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateMemCopy(const MemCopyQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMerge(const MergeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMinimum(const MinimumQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateMultiplication(const MultiplicationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateNormalization(const NormalizationQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateOutput(const OutputQueueDescriptor& 
descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreatePad(const PadQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreatePermute(const PermuteQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreatePooling3d(const Pooling3dQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreatePrelu(const PreluQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateQuantize(const QuantizeQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateQLstm(const QLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateRank(const RankQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateReduce(const ReduceQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateReshape(const ReshapeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateResize(const ResizeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateShape(const ShapeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSlice(const SliceQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSoftmax(const SoftmaxQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSubtraction(const SubtractionQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateSplitter(const SplitterQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateStack(const StackQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateSwitch(const SwitchQueueDescriptor& descriptor, + const WorkloadInfo& Info) const; + + virtual std::unique_ptr CreateTranspose(const TransposeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateTransposeConvolution2d( + const TransposeConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + + virtual std::unique_ptr CreateUnidirectionalSequenceLstm( + const UnidirectionalSequenceLstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + +private: + static bool IsLayerConfigurationSupported(const BackendId& backendId, + const IConnectableLayer& connectableLayer, + Optional dataType, + std::string& outReasonIfUnsupported, + const ModelOptions& modelOptions = {}); +}; + +} // namespace armnn diff --git 
a/include/armnnTestUtils/WorkloadTestUtils.hpp b/include/armnnTestUtils/WorkloadTestUtils.hpp
new file mode 100644
index 0000000000..156258a549
--- /dev/null
+++ b/include/armnnTestUtils/WorkloadTestUtils.hpp
@@ -0,0 +1,113 @@
+//
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace armnn
+{
+class ITensorHandle;
+} // namespace armnn
+
+namespace
+{
+
+template <typename QueueDescriptor>
+void AddInputToWorkload(QueueDescriptor& descriptor,
+                        armnn::WorkloadInfo& info,
+                        const armnn::TensorInfo& tensorInfo,
+                        armnn::ITensorHandle* tensorHandle)
+{
+    descriptor.m_Inputs.push_back(tensorHandle);
+    info.m_InputTensorInfos.push_back(tensorInfo);
+}
+
+template <typename QueueDescriptor>
+void AddOutputToWorkload(QueueDescriptor& descriptor,
+                         armnn::WorkloadInfo& info,
+                         const armnn::TensorInfo& tensorInfo,
+                         armnn::ITensorHandle* tensorHandle)
+{
+    descriptor.m_Outputs.push_back(tensorHandle);
+    info.m_OutputTensorInfos.push_back(tensorInfo);
+}
+
+template <typename QueueDescriptor>
+void SetWorkloadInput(QueueDescriptor& descriptor,
+                      armnn::WorkloadInfo& info,
+                      unsigned int index,
+                      const armnn::TensorInfo& tensorInfo,
+                      armnn::ITensorHandle* tensorHandle)
+{
+    descriptor.m_Inputs[index] = tensorHandle;
+    info.m_InputTensorInfos[index] = tensorInfo;
+}
+
+template <typename QueueDescriptor>
+void SetWorkloadOutput(QueueDescriptor& descriptor,
+                       armnn::WorkloadInfo& info,
+                       unsigned int index,
+                       const armnn::TensorInfo& tensorInfo,
+                       armnn::ITensorHandle* tensorHandle)
+{
+    descriptor.m_Outputs[index] = tensorHandle;
+    info.m_OutputTensorInfos[index] = tensorInfo;
+}
+
+inline void ExecuteWorkload(armnn::IWorkload& workload,
+                            const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+                            bool memoryManagementRequested = true)
+{
+    const bool manageMemory = memoryManager && memoryManagementRequested;
+
+    // Acquire working memory (if needed)
+    if (manageMemory)
+    {
+        memoryManager->Acquire();
+    }
+
+    // Perform PostAllocationConfiguration
+    workload.PostAllocationConfigure();
+
+    // Execute the workload
+    workload.Execute();
+
+    // Release working memory (if needed)
+    if (manageMemory)
+    {
+        memoryManager->Release();
+    }
+}
+
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+    if (!weightsType)
+    {
+        return weightsType;
+    }
+
+    switch(weightsType.value())
+    {
+        case armnn::DataType::BFloat16:
+        case armnn::DataType::Float16:
+        case armnn::DataType::Float32:
+            return weightsType;
+        case armnn::DataType::QAsymmS8:
+        case armnn::DataType::QAsymmU8:
+        case armnn::DataType::QSymmS8:
+        case armnn::DataType::QSymmS16:
+            return armnn::DataType::Signed32;
+        default:
+            ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+    }
+    return armnn::EmptyOptional();
+}
+
+} // anonymous namespace
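
The forwarding headers this change leaves behind in src/backends/backendsCommon are not shown above; the diff only adds the new headers under include/. As a rough sketch of the pattern described in the commit message (a forwarding header that emits a pragma deprecation message and pulls in the promoted header), one of them could look like the following. The file path and the exact message wording are illustrative, not copied from the tree:

// src/backends/backendsCommon/TensorHandle.hpp -- illustrative forwarding header
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

// Emit a build-time note pointing callers at the new public location.
#pragma message("backendsCommon/TensorHandle.hpp is deprecated, use armnn/backends/TensorHandle.hpp instead")

#include <armnn/backends/TensorHandle.hpp>

Resolving the resulting deprecation messages in Arm NN code then amounts to switching the include path at each call site:

// Before (emits the deprecation message):
//     #include <backendsCommon/TensorHandle.hpp>
// After:
#include <armnn/backends/TensorHandle.hpp>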
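
Of the promoted classes, ManagedConstTensorHandle in TensorHandle.hpp is the RAII helper described in its comments: Map() marks the handle as mapped and the destructor unmaps it. A minimal usage sketch, assuming a caller that already owns a std::shared_ptr<armnn::ConstTensorHandle> named weights (the function name and surrounding code are hypothetical):

// Illustrative only: 'weights' is an assumed shared_ptr owned by the caller;
// ManagedConstTensorHandle unmaps the memory when it goes out of scope.
void ReadWeights(std::shared_ptr<armnn::ConstTensorHandle> weights)
{
    armnn::ManagedConstTensorHandle managedWeights(weights);
    const void* data = managedWeights.Map();                  // throws if the handle is null
    const armnn::TensorInfo& info = managedWeights.GetTensorInfo();
    // ... read 'data' according to 'info' for the duration of this scope ...
}   // destructor calls Unmap() automatically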