author     Colm Donelan <colm.donelan@arm.com>    2021-12-10 12:43:54 +0000
committer  Colm Donelan <colm.donelan@arm.com>    2021-12-15 12:53:20 +0000
commit     0c47974f1800e8770904aecaef15d6f105758c4e (patch)
tree       f5424858c6fe6f33376b3432580179958ab8ac5a /include
parent     cdbb09f6e15ea6698a62197cf76ecba87b81cb9d (diff)
download   armnn-0c47974f1800e8770904aecaef15d6f105758c4e.tar.gz
IVGCVSW-6626 Promote backend headers in backendsCommon to armnn/backends

Move the following header files from backendsCommon to armnn/backends:
 * MemCopyWorkload.hpp
 * TensorHandle.hpp
 * Workload.hpp
 * WorkloadData.hpp
 * WorkloadFactory.hpp
Replace them with forwarding headers and a pragma deprecation message.
Resolve the deprecation messages in Arm NN code.

Signed-off-by: Colm Donelan <colm.donelan@arm.com>
Change-Id: I47f116b30f86e478c9057795bc518c391a8ae514
Diffstat (limited to 'include')
-rw-r--r--   include/armnn/backends/MemCopyWorkload.hpp      27
-rw-r--r--   include/armnn/backends/TensorHandle.hpp         267
-rw-r--r--   include/armnn/backends/Workload.hpp             219
-rw-r--r--   include/armnn/backends/WorkloadData.hpp         769
-rw-r--r--   include/armnn/backends/WorkloadFactory.hpp      289
-rw-r--r--   include/armnnTestUtils/WorkloadTestUtils.hpp    113
6 files changed, 1684 insertions, 0 deletions
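The diff is limited to 'include', so the forwarding headers left behind in backendsCommon are not shown here. As a rough illustration of the mechanism described in the commit message, a forwarding header with a pragma deprecation message could look like the sketch below; the exact message text and location are assumptions, not taken from this change.

//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

// Hypothetical forwarding header left at the old backendsCommon location.
#pragma message("backendsCommon/TensorHandle.hpp is deprecated, use armnn/backends/TensorHandle.hpp instead.")

#include <armnn/backends/TensorHandle.hpp>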
diff --git a/include/armnn/backends/MemCopyWorkload.hpp b/include/armnn/backends/MemCopyWorkload.hpp
new file mode 100644
index 0000000000..da23f52be6
--- /dev/null
+++ b/include/armnn/backends/MemCopyWorkload.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "TensorHandle.hpp"
+#include "Workload.hpp"
+
+#include <utility>
+
+namespace armnn
+{
+
+class CopyMemGenericWorkload : public BaseWorkload<MemCopyQueueDescriptor>
+{
+public:
+ CopyMemGenericWorkload(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void Execute() const override;
+ void ExecuteAsync(WorkingMemDescriptor& descriptor) override;
+
+private:
+ using TensorHandlePair = std::pair<const ITensorHandle*, ITensorHandle*>;
+ std::vector<TensorHandlePair> m_TensorHandlePairs;
+};
+
+} //namespace armnn
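As a usage sketch for the class above, a backend could drive CopyMemGenericWorkload directly to copy one tensor handle into another. The helper function below is illustrative only and not part of this change; the handles and tensor info are assumed to come from the caller.

#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnn/Tensor.hpp>

// Copies the contents of src into dst; both handles are assumed to describe tensors matching `info`.
void CopyTensor(armnn::ITensorHandle* src, armnn::ITensorHandle* dst, const armnn::TensorInfo& info)
{
    armnn::MemCopyQueueDescriptor descriptor;
    descriptor.m_Inputs.push_back(src);
    descriptor.m_Outputs.push_back(dst);

    armnn::WorkloadInfo workloadInfo;
    workloadInfo.m_InputTensorInfos.push_back(info);
    workloadInfo.m_OutputTensorInfos.push_back(info);

    // The constructor validates the descriptor against workloadInfo; Execute() performs the copy.
    armnn::CopyMemGenericWorkload workload(descriptor, workloadInfo);
    workload.Execute();
}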
diff --git a/include/armnn/backends/TensorHandle.hpp b/include/armnn/backends/TensorHandle.hpp
new file mode 100644
index 0000000000..2e6c8485d1
--- /dev/null
+++ b/include/armnn/backends/TensorHandle.hpp
@@ -0,0 +1,267 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ITensorHandle.hpp"
+
+#include <armnn/TypesUtils.hpp>
+#include <armnn/utility/Assert.hpp>
+#include <armnnUtils/CompatibleTypes.hpp>
+
+#include <algorithm>
+
+namespace armnn
+{
+
+// Get a TensorShape representing the strides (in bytes) for each dimension
+// of a tensor, assuming fully packed data with no padding
+TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo);
+
+// Abstract tensor handles wrapping a readable region of memory, interpreting it as tensor data.
+class ConstTensorHandle : public ITensorHandle
+{
+public:
+ template <typename T>
+ const T* GetConstTensor() const
+ {
+ if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
+ {
+ return reinterpret_cast<const T*>(m_Memory);
+ }
+ else
+ {
+ throw armnn::Exception("Attempting to get a tensor of an incompatible type!");
+ }
+ }
+
+ const TensorInfo& GetTensorInfo() const
+ {
+ return m_TensorInfo;
+ }
+
+ virtual void Manage() override {}
+
+ virtual ITensorHandle* GetParent() const override { return nullptr; }
+
+ virtual const void* Map(bool /* blocking = true */) const override { return m_Memory; }
+ virtual void Unmap() const override {}
+
+ TensorShape GetStrides() const override
+ {
+ return GetUnpaddedTensorStrides(m_TensorInfo);
+ }
+ TensorShape GetShape() const override { return m_TensorInfo.GetShape(); }
+
+protected:
+ ConstTensorHandle(const TensorInfo& tensorInfo);
+
+ void SetConstMemory(const void* mem) { m_Memory = mem; }
+
+private:
+ // Only used for testing
+ void CopyOutTo(void *) const override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
+ void CopyInFrom(const void*) override { ARMNN_ASSERT_MSG(false, "Unimplemented"); }
+
+ ConstTensorHandle(const ConstTensorHandle& other) = delete;
+ ConstTensorHandle& operator=(const ConstTensorHandle& other) = delete;
+
+ TensorInfo m_TensorInfo;
+ const void* m_Memory;
+};
+
+template<>
+const void* ConstTensorHandle::GetConstTensor<void>() const;
+
+// Abstract specialization of ConstTensorHandle that allows write access to the same data.
+class TensorHandle : public ConstTensorHandle
+{
+public:
+ template <typename T>
+ T* GetTensor() const
+ {
+ if (armnnUtils::CompatibleTypes<T>(GetTensorInfo().GetDataType()))
+ {
+ return reinterpret_cast<T*>(m_MutableMemory);
+ }
+ else
+ {
+ throw armnn::Exception("Attempting to get a tensor of an incompatible type!");
+ }
+ }
+
+protected:
+ TensorHandle(const TensorInfo& tensorInfo);
+
+ void SetMemory(void* mem)
+ {
+ m_MutableMemory = mem;
+ SetConstMemory(m_MutableMemory);
+ }
+
+private:
+
+ TensorHandle(const TensorHandle& other) = delete;
+ TensorHandle& operator=(const TensorHandle& other) = delete;
+ void* m_MutableMemory;
+};
+
+template <>
+void* TensorHandle::GetTensor<void>() const;
+
+// A TensorHandle that owns the wrapped memory region.
+class ScopedTensorHandle : public TensorHandle
+{
+public:
+ explicit ScopedTensorHandle(const TensorInfo& tensorInfo);
+
+ // Copies contents from Tensor.
+ explicit ScopedTensorHandle(const ConstTensor& tensor);
+
+ // Copies contents from ConstTensorHandle
+ explicit ScopedTensorHandle(const ConstTensorHandle& tensorHandle);
+
+ ScopedTensorHandle(const ScopedTensorHandle& other);
+ ScopedTensorHandle& operator=(const ScopedTensorHandle& other);
+ ~ScopedTensorHandle();
+
+ virtual void Allocate() override;
+
+private:
+ // Only used for testing
+ void CopyOutTo(void* memory) const override;
+ void CopyInFrom(const void* memory) override;
+
+ void CopyFrom(const ScopedTensorHandle& other);
+ void CopyFrom(const void* srcMemory, unsigned int numBytes);
+};
+
+// A TensorHandle that wraps an already allocated memory region.
+//
+// Clients must make sure the passed in memory region stays alive for the lifetime of
+// the PassthroughTensorHandle instance.
+//
+// Note there is no polymorphism to/from ConstPassthroughTensorHandle.
+class PassthroughTensorHandle : public TensorHandle
+{
+public:
+ PassthroughTensorHandle(const TensorInfo& tensorInfo, void* mem)
+ : TensorHandle(tensorInfo)
+ {
+ SetMemory(mem);
+ }
+
+ virtual void Allocate() override;
+};
+
+// A ConstTensorHandle that wraps an already allocated memory region.
+//
+// This allows users to pass in const memory to a network.
+// Clients must make sure the passed in memory region stays alive for the lifetime of
+// the PassthroughTensorHandle instance.
+//
+// Note there is no polymorphism to/from PassthroughTensorHandle.
+class ConstPassthroughTensorHandle : public ConstTensorHandle
+{
+public:
+ ConstPassthroughTensorHandle(const TensorInfo& tensorInfo, const void* mem)
+ : ConstTensorHandle(tensorInfo)
+ {
+ SetConstMemory(mem);
+ }
+
+ virtual void Allocate() override;
+};
+
+
+// Template specializations.
+
+template <>
+const void* ConstTensorHandle::GetConstTensor() const;
+
+template <>
+void* TensorHandle::GetTensor() const;
+
+class ManagedConstTensorHandle
+{
+
+public:
+ explicit ManagedConstTensorHandle(std::shared_ptr<ConstTensorHandle> ptr)
+ : m_Mapped(false)
+ , m_TensorHandle(std::move(ptr)) {};
+
+ /// RAII-managed resource: unmaps the memory area once it goes out of scope.
+ const void* Map(bool blocking = true)
+ {
+ if (m_TensorHandle)
+ {
+ auto pRet = m_TensorHandle->Map(blocking);
+ m_Mapped = true;
+ return pRet;
+ }
+ else
+ {
+ throw armnn::Exception("Attempting to Map null TensorHandle");
+ }
+
+ }
+
+ // Delete copy constructor as it's unnecessary
+ ManagedConstTensorHandle(const ManagedConstTensorHandle& other) = delete;
+
+ // Delete copy assignment as it's unnecessary
+ ManagedConstTensorHandle& operator=(const ManagedConstTensorHandle& other) = delete;
+
+ // Delete move assignment as it's unnecessary
+ ManagedConstTensorHandle& operator=(ManagedConstTensorHandle&& other) noexcept = delete;
+
+ ~ManagedConstTensorHandle()
+ {
+ // The handle may be null (e.g. optional bias tensor handles are left empty), so only unmap if it exists.
+ if (m_TensorHandle)
+ {
+ Unmap();
+ }
+ }
+
+ void Unmap()
+ {
+ // Only unmap if mapped and TensorHandle exists.
+ if (m_Mapped && m_TensorHandle)
+ {
+ m_TensorHandle->Unmap();
+ m_Mapped = false;
+ }
+ }
+
+ const TensorInfo& GetTensorInfo() const
+ {
+ return m_TensorHandle->GetTensorInfo();
+ }
+
+ bool IsMapped() const
+ {
+ return m_Mapped;
+ }
+
+private:
+ bool m_Mapped;
+ std::shared_ptr<ConstTensorHandle> m_TensorHandle;
+};
+
+using ConstCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstCpuTensorHandle is deprecated, "
+ "use ConstTensorHandle instead", "22.05") = ConstTensorHandle;
+using CpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("CpuTensorHandle is deprecated, "
+ "use TensorHandle instead", "22.05") = TensorHandle;
+using ScopedCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ScopedCpuTensorHandle is deprecated, "
+ "use ScopedTensorHandle instead", "22.05") = ScopedTensorHandle;
+using PassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("PassthroughCpuTensorHandle is deprecated, use "
+ "PassthroughTensorHandle instead",
+ "22.05") = PassthroughTensorHandle;
+using ConstPassthroughCpuTensorHandle ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ConstPassthroughCpuTensorHandle is "
+ "deprecated, use ConstPassthroughTensorHandle "
+ "instead", "22.05") = ConstPassthroughTensorHandle;
+
+} // namespace armnn
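A short, self-contained sketch of how these handle types fit together; the shape, values and function name are illustrative only.

#include <armnn/backends/TensorHandle.hpp>
#include <armnn/Tensor.hpp>

#include <memory>

void TensorHandleExample()
{
    armnn::TensorInfo info({ 1, 2, 2, 1 }, armnn::DataType::Float32);

    // ScopedTensorHandle owns its memory; GetTensor<T>() throws if T is incompatible with the DataType.
    armnn::ScopedTensorHandle scoped(info);
    scoped.Allocate();
    scoped.GetTensor<float>()[0] = 1.0f;

    // ConstPassthroughTensorHandle wraps caller-owned memory, which must outlive the handle.
    const float buffer[4] = { 0.0f, 1.0f, 2.0f, 3.0f };
    armnn::ConstPassthroughTensorHandle passthrough(info, buffer);

    // ManagedConstTensorHandle is an RAII wrapper: whatever it maps is unmapped again in its destructor.
    auto shared = std::make_shared<armnn::ScopedTensorHandle>(info);
    shared->Allocate();
    armnn::ManagedConstTensorHandle managed(shared);
    const float* mapped = static_cast<const float*>(managed.Map());
    (void)mapped;        // managed.Unmap() runs automatically when `managed` is destroyed
    (void)passthrough;
}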
diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp
new file mode 100644
index 0000000000..7c1bda50bc
--- /dev/null
+++ b/include/armnn/backends/Workload.hpp
@@ -0,0 +1,219 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "IWorkload.hpp"
+#include "WorkloadData.hpp"
+#include "WorkloadInfo.hpp"
+#include "WorkingMemDescriptor.hpp"
+
+#include <Profiling.hpp>
+#include <ProfilingService.hpp>
+
+#include <algorithm>
+
+namespace armnn
+{
+
+// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
+// in the various workload factories.
+// There should never be an instantiation of a NullWorkload.
+class NullWorkload : public IWorkload
+{
+ NullWorkload() = delete;
+};
+
+template <typename QueueDescriptor>
+class BaseWorkload : public IWorkload
+{
+public:
+
+ BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : m_Data(descriptor),
+ m_Guid(profiling::ProfilingService::GetNextGuid())
+ {
+ m_Data.Validate(info);
+ }
+
+ void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override
+ {
+ ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
+ std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
+
+ m_Data.m_Inputs = workingMemDescriptor.m_Inputs;
+ m_Data.m_Outputs = workingMemDescriptor.m_Outputs;
+
+ Execute();
+ };
+
+ void PostAllocationConfigure() override {}
+
+ const QueueDescriptor& GetData() const { return m_Data; }
+
+ profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
+
+protected:
+ QueueDescriptor m_Data;
+ const profiling::ProfilingGuid m_Guid;
+
+private:
+ std::mutex m_AsyncWorkloadMutex;
+};
+
+// TypedWorkload: checks that the input and output tensors all use one of the given DataTypes.
+template <typename QueueDescriptor, armnn::DataType... DataTypes>
+class TypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+ TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {
+ std::vector<armnn::DataType> dataTypes = {DataTypes...};
+ armnn::DataType expectedInputType;
+
+ if (!info.m_InputTensorInfos.empty())
+ {
+ expectedInputType = info.m_InputTensorInfos.front().GetDataType();
+
+ if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
+ {
+ ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+ }
+ ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
+ info.m_InputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == expectedInputType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+ armnn::DataType expectedOutputType;
+
+ if (!info.m_OutputTensorInfos.empty())
+ {
+ expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
+
+ if (!info.m_InputTensorInfos.empty())
+ {
+ if (expectedOutputType != expectedInputType)
+ {
+ ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+ }
+ }
+ else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
+ {
+ ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
+ }
+ ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
+ info.m_OutputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == expectedOutputType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+ }
+};
+
+template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
+class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+ MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {
+ ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
+ info.m_InputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == InputDataType;
+ }),
+ "Trying to create workload with incorrect type");
+
+ ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+ info.m_OutputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == OutputDataType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+};
+
+// FirstInputTypedWorkload used to check type of the first input
+template <typename QueueDescriptor, armnn::DataType DataType>
+class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
+{
+public:
+
+ FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {
+ if (!info.m_InputTensorInfos.empty())
+ {
+ ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
+ "Trying to create workload with incorrect type");
+ }
+
+ ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
+ info.m_OutputTensorInfos.end(),
+ [&](auto it){
+ return it.GetDataType() == DataType;
+ }),
+ "Trying to create workload with incorrect type");
+ }
+};
+
+template <typename QueueDescriptor>
+using FloatWorkload = TypedWorkload<QueueDescriptor,
+ armnn::DataType::Float16,
+ armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
+using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
+using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
+
+template <typename QueueDescriptor>
+using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
+
+template <typename QueueDescriptor>
+using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
+
+template <typename QueueDescriptor>
+using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::Float32,
+ armnn::DataType::Boolean>;
+
+template <typename QueueDescriptor>
+using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::QAsymmU8,
+ armnn::DataType::Boolean>;
+
+template <typename QueueDescriptor>
+using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::BFloat16,
+ armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
+using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::Float32,
+ armnn::DataType::BFloat16>;
+
+template <typename QueueDescriptor>
+using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::Float16,
+ armnn::DataType::Float32>;
+
+template <typename QueueDescriptor>
+using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::Float32,
+ armnn::DataType::Float16>;
+
+template <typename QueueDescriptor>
+using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
+ armnn::DataType::QAsymmU8,
+ armnn::DataType::Float32>;
+
+} //namespace armnn
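To illustrate how a backend builds on these templates, here is a hypothetical Float32-only activation workload. The sample namespace, class name and empty Execute() body are assumptions for illustration, not part of this change.

#include <armnn/backends/Workload.hpp>
#include <armnn/backends/WorkloadData.hpp>

namespace sample
{

// Float32Workload<Desc> is TypedWorkload<Desc, DataType::Float32>, so the base constructor
// validates both the queue descriptor and the tensor data types at construction time.
class SampleActivationWorkload : public armnn::Float32Workload<armnn::ActivationQueueDescriptor>
{
public:
    SampleActivationWorkload(const armnn::ActivationQueueDescriptor& descriptor,
                             const armnn::WorkloadInfo& info)
        : armnn::Float32Workload<armnn::ActivationQueueDescriptor>(descriptor, info)
    {}

    void Execute() const override
    {
        // m_Data holds the validated descriptor; a real backend would map m_Data.m_Inputs and
        // m_Data.m_Outputs here and apply the activation function from m_Data.m_Parameters.
    }
};

} // namespace sample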
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
new file mode 100644
index 0000000000..7406547216
--- /dev/null
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -0,0 +1,769 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "TensorHandle.hpp"
+
+#include <armnn/Deprecated.hpp>
+#include <armnn/Descriptors.hpp>
+#include <armnn/Exceptions.hpp>
+#include <armnn/Types.hpp>
+#include <armnn/Tensor.hpp>
+#include <common/include/ProfilingGuid.hpp>
+
+namespace armnn
+{
+
+// A helper function that returns the bias data type required for a given input data type.
+DataType GetBiasDataType(DataType inputDataType);
+
+struct WorkloadInfo;
+
+struct QueueDescriptor
+{
+ std::vector<ITensorHandle*> m_Inputs;
+ std::vector<ITensorHandle*> m_Outputs;
+ void* m_AdditionalInfoObject;
+
+ void ValidateInputsOutputs(const std::string& descName,
+ unsigned int numExpectedIn,
+ unsigned int numExpectedOut) const;
+
+ template<typename T>
+ const T* GetAdditionalInformation() const
+ {
+ return static_cast<T*>(m_AdditionalInfoObject);
+ }
+
+protected:
+ ~QueueDescriptor() = default;
+ QueueDescriptor()
+ : m_AdditionalInfoObject(nullptr)
+ {}
+ QueueDescriptor(QueueDescriptor const&) = default;
+ QueueDescriptor& operator=(QueueDescriptor const&) = default;
+};
+
+// Base class for queue descriptors which contain parameters.
+template <typename LayerDescriptor>
+struct QueueDescriptorWithParameters : public QueueDescriptor
+{
+ LayerDescriptor m_Parameters;
+
+protected:
+ ~QueueDescriptorWithParameters() = default;
+ QueueDescriptorWithParameters() = default;
+ QueueDescriptorWithParameters(QueueDescriptorWithParameters const&) = default;
+ QueueDescriptorWithParameters& operator=(QueueDescriptorWithParameters const&) = default;
+};
+
+struct MapQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct UnmapQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct MemCopyQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+using InputQueueDescriptor = MemCopyQueueDescriptor;
+using OutputQueueDescriptor = MemCopyQueueDescriptor;
+
+struct MemImportQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct MemSyncQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Softmax layer workload data.
+struct SoftmaxQueueDescriptor : QueueDescriptorWithParameters<SoftmaxDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Splitter layer workload data.
+struct SplitterQueueDescriptor : QueueDescriptorWithParameters<ViewsDescriptor>
+{
+ struct ViewOrigin
+ {
+ ViewOrigin() {}
+ ViewOrigin(std::vector<unsigned int> const& origin) : m_Origin(origin) {}
+
+ //View origin (size of the vector is the same as number of dimensions of the view).
+ std::vector<unsigned int> m_Origin;
+ };
+
+ //View defines a tensor that will be carved from the input tensor.
+ //View origins are stored here, the extents are defined by sizes of the output tensors.
+ std::vector<ViewOrigin> m_ViewOrigins;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Concat layer workload data.
+struct ConcatQueueDescriptor : QueueDescriptorWithParameters<OriginsDescriptor>
+{
+ struct ViewOrigin
+ {
+ ViewOrigin() {}
+ ViewOrigin(const std::vector<unsigned int>& origin) : m_Origin(origin) {}
+
+ //View origin (size of the vector is the same as number of dimensions of the view).
+ std::vector<unsigned int> m_Origin;
+ };
+
+ //View defines a sub-area of the output tensor that will be filled with the corresponding input tensor.
+ //View origins are stored here, the extents are defined by sizes of the input tensors.
+ std::vector<ViewOrigin> m_ViewOrigins;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Deprecated. Use ConcatQueueDescriptor instead
+using MergerQueueDescriptor = ConcatQueueDescriptor;
+
+// Stack layer workload data.
+struct StackQueueDescriptor : QueueDescriptorWithParameters<StackDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Activation layer workload data.
+struct ActivationQueueDescriptor : QueueDescriptorWithParameters<ActivationDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ArgMinMaxQueueDescriptor : QueueDescriptorWithParameters<ArgMinMaxDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct CastQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Fill layer workload data.
+struct FillQueueDescriptor : QueueDescriptorWithParameters<FillDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Fully connected layer workload data.
+struct FullyConnectedQueueDescriptor : QueueDescriptorWithParameters<FullyConnectedDescriptor>
+{
+ FullyConnectedQueueDescriptor()
+ : m_Weight(nullptr)
+ , m_Bias(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Weight;
+ const ConstTensorHandle* m_Bias;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Permute layer workload data.
+struct PermuteQueueDescriptor : QueueDescriptorWithParameters<PermuteDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Pooling 2D layer workload data.
+struct Pooling2dQueueDescriptor : QueueDescriptorWithParameters<Pooling2dDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Pooling 3D layer workload data.
+struct Pooling3dQueueDescriptor : QueueDescriptorWithParameters<Pooling3dDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+
+// Convolution 2D layer workload data.
+struct Convolution2dQueueDescriptor : QueueDescriptorWithParameters<Convolution2dDescriptor>
+{
+ Convolution2dQueueDescriptor()
+ : m_Weight(nullptr)
+ , m_Bias(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Weight;
+ const ConstTensorHandle* m_Bias;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Convolution 3D layer workload data.
+struct Convolution3dQueueDescriptor : QueueDescriptorWithParameters<Convolution3dDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+/// Depthwise Convolution 2D layer workload data.
+///
+/// @note
+/// The weights are in the format [1, H, W, I*M], where I is the input channel size, M is the depthwise multiplier and
+/// H, W are the height and width of the filter kernel. If per channel quantization is applied
+/// the weights will be quantized along the last dimension/axis (I*M) which corresponds to the output channel size.
+/// If per channel quantization is applied the weights tensor will have I*M scales, one for each dimension
+/// of the quantization axis. You have to be aware of this when reshaping the weights tensor.
+/// Splitting the I*M axis, e.g. [1, H, W, I*M] --> [H, W, I, M], won't work without taking care of the
+/// corresponding quantization scales.
+/// If no per channel quantization is applied, reshaping the weights tensor won't cause any issues. Preconfigured
+/// permutation functions are available in @link WorkloadUtils.hpp.
+///
+struct DepthwiseConvolution2dQueueDescriptor : QueueDescriptorWithParameters<DepthwiseConvolution2dDescriptor>
+{
+ DepthwiseConvolution2dQueueDescriptor()
+ : m_Weight(nullptr)
+ , m_Bias(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Weight;
+ const ConstTensorHandle* m_Bias;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct DetectionPostProcessQueueDescriptor : QueueDescriptorWithParameters<DetectionPostProcessDescriptor>
+{
+ DetectionPostProcessQueueDescriptor()
+ : m_Anchors(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Anchors;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Normalization layer workload data.
+struct NormalizationQueueDescriptor : QueueDescriptorWithParameters<NormalizationDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Add layer workload data.
+struct AdditionQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Multiplication layer workload data.
+struct MultiplicationQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Division layer workload data.
+struct DivisionQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Subtraction layer workload data.
+struct SubtractionQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Maximum layer workload data.
+struct MaximumQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Mean layer workload data.
+struct MeanQueueDescriptor : QueueDescriptorWithParameters<MeanDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Pad layer workload data
+struct PadQueueDescriptor : QueueDescriptorWithParameters<PadDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct QuantizeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Deprecated. Use ComparisonQueueDescriptor instead.
+struct EqualQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Batch norm layer workload data.
+struct BatchNormalizationQueueDescriptor : QueueDescriptorWithParameters<BatchNormalizationDescriptor>
+{
+ BatchNormalizationQueueDescriptor()
+ : m_Mean(nullptr)
+ , m_Variance(nullptr)
+ , m_Beta(nullptr)
+ , m_Gamma(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Mean;
+ const ConstTensorHandle* m_Variance;
+ const ConstTensorHandle* m_Beta;
+ const ConstTensorHandle* m_Gamma;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct RankQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+ARMNN_NO_DEPRECATE_WARN_BEGIN
+struct
+ARMNN_DEPRECATED_MSG_REMOVAL_DATE("ResizeBilinearQueueDescriptor is deprecated use ResizeQueueDescriptor instead",
+ "22.08")
+ResizeBilinearQueueDescriptor : QueueDescriptorWithParameters<ResizeBilinearDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+ARMNN_NO_DEPRECATE_WARN_END
+
+struct ResizeQueueDescriptor : QueueDescriptorWithParameters<ResizeDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct FakeQuantizationQueueDescriptor : QueueDescriptorWithParameters<FakeQuantizationDescriptor>
+{
+ FakeQuantizationQueueDescriptor()
+ : m_Min(nullptr)
+ , m_Max(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_Min;
+ const ConstTensorHandle* m_Max;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct InstanceNormalizationQueueDescriptor : QueueDescriptorWithParameters<InstanceNormalizationDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct L2NormalizationQueueDescriptor : QueueDescriptorWithParameters<L2NormalizationDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct LogSoftmaxQueueDescriptor : QueueDescriptorWithParameters<LogSoftmaxDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConstantQueueDescriptor : QueueDescriptor
+{
+ ConstantQueueDescriptor()
+ : m_LayerOutput(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_LayerOutput;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ReshapeQueueDescriptor : QueueDescriptorWithParameters<ReshapeDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct SpaceToBatchNdQueueDescriptor : QueueDescriptorWithParameters<SpaceToBatchNdDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct SpaceToDepthQueueDescriptor : QueueDescriptorWithParameters<SpaceToDepthDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct FloorQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct LstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+ LstmQueueDescriptor()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_InputToInputWeights;
+ const ConstTensorHandle* m_InputToForgetWeights;
+ const ConstTensorHandle* m_InputToCellWeights;
+ const ConstTensorHandle* m_InputToOutputWeights;
+ const ConstTensorHandle* m_RecurrentToInputWeights;
+ const ConstTensorHandle* m_RecurrentToForgetWeights;
+ const ConstTensorHandle* m_RecurrentToCellWeights;
+ const ConstTensorHandle* m_RecurrentToOutputWeights;
+ const ConstTensorHandle* m_CellToInputWeights;
+ const ConstTensorHandle* m_CellToForgetWeights;
+ const ConstTensorHandle* m_CellToOutputWeights;
+ const ConstTensorHandle* m_InputGateBias;
+ const ConstTensorHandle* m_ForgetGateBias;
+ const ConstTensorHandle* m_CellBias;
+ const ConstTensorHandle* m_OutputGateBias;
+ const ConstTensorHandle* m_ProjectionWeights;
+ const ConstTensorHandle* m_ProjectionBias;
+ const ConstTensorHandle* m_InputLayerNormWeights;
+ const ConstTensorHandle* m_ForgetLayerNormWeights;
+ const ConstTensorHandle* m_CellLayerNormWeights;
+ const ConstTensorHandle* m_OutputLayerNormWeights;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertBf16ToFp32QueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertFp32ToBf16QueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertFp16ToFp32QueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ConvertFp32ToFp16QueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct BatchToSpaceNdQueueDescriptor : QueueDescriptorWithParameters<BatchToSpaceNdDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct StridedSliceQueueDescriptor : QueueDescriptorWithParameters<StridedSliceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Minimum layer workload data.
+struct MinimumQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+// Deprecated. Use ComparisonQueueDescriptor instead.
+struct GreaterQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct DebugQueueDescriptor : QueueDescriptor
+{
+ DebugQueueDescriptor() : m_Guid(0) {}
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+
+ LayerGuid m_Guid;
+ std::string m_LayerName;
+ unsigned int m_SlotIndex;
+};
+
+struct RsqrtQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct GatherQueueDescriptor : QueueDescriptorWithParameters<GatherDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct PreCompiledQueueDescriptor : QueueDescriptorWithParameters<PreCompiledDescriptor>
+{
+ PreCompiledQueueDescriptor()
+ : m_PreCompiledObject(nullptr)
+ {
+ }
+
+ void* m_PreCompiledObject;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct DequantizeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct MergeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct SwitchQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct PreluQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct TransposeConvolution2dQueueDescriptor : QueueDescriptorWithParameters<TransposeConvolution2dDescriptor>
+{
+ TransposeConvolution2dQueueDescriptor() :
+ m_Weight(nullptr),
+ m_Bias(nullptr)
+ {}
+
+ const ConstTensorHandle* m_Weight;
+ const ConstTensorHandle* m_Bias;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct TransposeQueueDescriptor : QueueDescriptorWithParameters<TransposeDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct QLstmQueueDescriptor : QueueDescriptorWithParameters<QLstmDescriptor>
+{
+ QLstmQueueDescriptor()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_InputToInputWeights;
+ const ConstTensorHandle* m_InputToForgetWeights;
+ const ConstTensorHandle* m_InputToCellWeights;
+ const ConstTensorHandle* m_InputToOutputWeights;
+ const ConstTensorHandle* m_RecurrentToInputWeights;
+ const ConstTensorHandle* m_RecurrentToForgetWeights;
+ const ConstTensorHandle* m_RecurrentToCellWeights;
+ const ConstTensorHandle* m_RecurrentToOutputWeights;
+ const ConstTensorHandle* m_CellToInputWeights;
+ const ConstTensorHandle* m_CellToForgetWeights;
+ const ConstTensorHandle* m_CellToOutputWeights;
+ const ConstTensorHandle* m_InputGateBias;
+ const ConstTensorHandle* m_ForgetGateBias;
+ const ConstTensorHandle* m_CellBias;
+ const ConstTensorHandle* m_OutputGateBias;
+ const ConstTensorHandle* m_ProjectionWeights;
+ const ConstTensorHandle* m_ProjectionBias;
+ const ConstTensorHandle* m_InputLayerNormWeights;
+ const ConstTensorHandle* m_ForgetLayerNormWeights;
+ const ConstTensorHandle* m_CellLayerNormWeights;
+ const ConstTensorHandle* m_OutputLayerNormWeights;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct QuantizedLstmQueueDescriptor : QueueDescriptor
+{
+ QuantizedLstmQueueDescriptor()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ {}
+
+ const ConstTensorHandle* m_InputToInputWeights;
+ const ConstTensorHandle* m_InputToForgetWeights;
+ const ConstTensorHandle* m_InputToCellWeights;
+ const ConstTensorHandle* m_InputToOutputWeights;
+
+ const ConstTensorHandle* m_RecurrentToInputWeights;
+ const ConstTensorHandle* m_RecurrentToForgetWeights;
+ const ConstTensorHandle* m_RecurrentToCellWeights;
+ const ConstTensorHandle* m_RecurrentToOutputWeights;
+
+ const ConstTensorHandle* m_InputGateBias;
+ const ConstTensorHandle* m_ForgetGateBias;
+ const ConstTensorHandle* m_CellBias;
+ const ConstTensorHandle* m_OutputGateBias;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct AbsQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct SliceQueueDescriptor : QueueDescriptorWithParameters<SliceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct DepthToSpaceQueueDescriptor : QueueDescriptorWithParameters<DepthToSpaceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ComparisonQueueDescriptor : QueueDescriptorWithParameters<ComparisonDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ElementwiseUnaryQueueDescriptor : QueueDescriptorWithParameters<ElementwiseUnaryDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct LogicalBinaryQueueDescriptor : QueueDescriptorWithParameters<LogicalBinaryDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ReduceQueueDescriptor : QueueDescriptorWithParameters<ReduceDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ShapeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+ UnidirectionalSequenceLstmQueueDescriptor()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ , m_InputLayerNormWeights(nullptr)
+ , m_ForgetLayerNormWeights(nullptr)
+ , m_CellLayerNormWeights(nullptr)
+ , m_OutputLayerNormWeights(nullptr)
+ {
+ }
+
+ const ConstTensorHandle* m_InputToInputWeights;
+ const ConstTensorHandle* m_InputToForgetWeights;
+ const ConstTensorHandle* m_InputToCellWeights;
+ const ConstTensorHandle* m_InputToOutputWeights;
+ const ConstTensorHandle* m_RecurrentToInputWeights;
+ const ConstTensorHandle* m_RecurrentToForgetWeights;
+ const ConstTensorHandle* m_RecurrentToCellWeights;
+ const ConstTensorHandle* m_RecurrentToOutputWeights;
+ const ConstTensorHandle* m_CellToInputWeights;
+ const ConstTensorHandle* m_CellToForgetWeights;
+ const ConstTensorHandle* m_CellToOutputWeights;
+ const ConstTensorHandle* m_InputGateBias;
+ const ConstTensorHandle* m_ForgetGateBias;
+ const ConstTensorHandle* m_CellBias;
+ const ConstTensorHandle* m_OutputGateBias;
+ const ConstTensorHandle* m_ProjectionWeights;
+ const ConstTensorHandle* m_ProjectionBias;
+ const ConstTensorHandle* m_InputLayerNormWeights;
+ const ConstTensorHandle* m_ForgetLayerNormWeights;
+ const ConstTensorHandle* m_CellLayerNormWeights;
+ const ConstTensorHandle* m_OutputLayerNormWeights;
+
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct ChannelShuffleQueueDescriptor : QueueDescriptorWithParameters<ChannelShuffleDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+} // namespace armnn
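The descriptors above are plain data carriers: a backend or test fills in the handles and parameters, then calls Validate() against a WorkloadInfo before building a workload. The function below is a small illustrative sketch; the handles and tensor info are assumed to come from the caller.

#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadInfo.hpp>

void BuildAndValidateSoftmax(armnn::ITensorHandle* input,
                             armnn::ITensorHandle* output,
                             const armnn::TensorInfo& tensorInfo)
{
    armnn::SoftmaxQueueDescriptor descriptor;
    descriptor.m_Parameters.m_Beta = 1.0f;        // layer parameters live in m_Parameters
    descriptor.m_Inputs.push_back(input);
    descriptor.m_Outputs.push_back(output);

    armnn::WorkloadInfo info;
    info.m_InputTensorInfos.push_back(tensorInfo);
    info.m_OutputTensorInfos.push_back(tensorInfo);

    // Throws if the inputs/outputs are inconsistent with a softmax layer.
    descriptor.Validate(info);
}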
diff --git a/include/armnn/backends/WorkloadFactory.hpp b/include/armnn/backends/WorkloadFactory.hpp
new file mode 100644
index 0000000000..68ad2e3741
--- /dev/null
+++ b/include/armnn/backends/WorkloadFactory.hpp
@@ -0,0 +1,289 @@
+//
+// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "ITensorHandle.hpp"
+#include "Workload.hpp"
+
+#include <armnn/Optional.hpp>
+#include <armnn/INetwork.hpp>
+#include <armnn/TensorFwd.hpp>
+
+#include <memory>
+
+namespace armnn
+{
+
+class Layer;
+
+// Workload factory interface for compute backends.
+class IWorkloadFactory
+{
+public:
+ virtual ~IWorkloadFactory() { }
+
+ virtual void AfterWorkloadsCreated() {};
+
+ virtual const BackendId& GetBackendId() const = 0;
+
+ static bool IsLayerSupported(const BackendId& backendId,
+ const IConnectableLayer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported);
+
+ static bool IsLayerSupported(const IConnectableLayer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported);
+
+ static bool IsLayerSupported(const IConnectableLayer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions);
+
+ static bool IsLayerSupported(const BackendId& backendId,
+ const IConnectableLayer& layer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions);
+
+ virtual bool SupportsSubTensors() const = 0;
+
+ ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
+ virtual std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
+ TensorShape const& subTensorShape,
+ unsigned int const* subTensorOrigin
+ ) const = 0;
+
+ virtual std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const = 0;
+
+ ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ const bool IsMemoryManaged = true) const = 0;
+
+ ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout,
+ const bool IsMemoryManaged = true) const = 0;
+
+ virtual std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateCast(const CastQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateChannelShuffle(const ChannelShuffleQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateConvolution3d(const Convolution3dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(
+ const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDetectionPostProcess(
+ const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateInstanceNormalization(
+ const InstanceNormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePooling3d(const Pooling3dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateQLstm(const QLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateRank(const RankQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateReduce(const ReduceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateShape(const ShapeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateSwitch(const SwitchQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateTransposeConvolution2d(
+ const TransposeConvolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+ virtual std::unique_ptr<IWorkload> CreateUnidirectionalSequenceLstm(
+ const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
+private:
+ static bool IsLayerConfigurationSupported(const BackendId& backendId,
+ const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions = {});
+};
+
+} // namespace armnn
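A typical caller checks layer support and then asks the backend's factory for a workload. The helper below is a sketch only; the Float32 data type and the choice of CreateAddition are assumptions for illustration.

#include <armnn/backends/WorkloadFactory.hpp>

#include <memory>
#include <string>

// Returns nullptr if the backend rejects the layer or does not provide an addition workload.
std::unique_ptr<armnn::IWorkload> MakeAdditionWorkload(armnn::IWorkloadFactory& factory,
                                                       const armnn::IConnectableLayer& layer,
                                                       const armnn::AdditionQueueDescriptor& descriptor,
                                                       const armnn::WorkloadInfo& info)
{
    std::string reason;
    if (!armnn::IWorkloadFactory::IsLayerSupported(factory.GetBackendId(), layer,
                                                   armnn::DataType::Float32, reason))
    {
        return nullptr;   // `reason` explains why the layer is unsupported on this backend
    }
    return factory.CreateAddition(descriptor, info);
}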
diff --git a/include/armnnTestUtils/WorkloadTestUtils.hpp b/include/armnnTestUtils/WorkloadTestUtils.hpp
new file mode 100644
index 0000000000..156258a549
--- /dev/null
+++ b/include/armnnTestUtils/WorkloadTestUtils.hpp
@@ -0,0 +1,113 @@
+//
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include <armnn/Tensor.hpp>
+
+#include <armnn/backends/IBackendInternal.hpp>
+#include <armnn/backends/IMemoryManager.hpp>
+#include <armnn/backends/Workload.hpp>
+#include <armnn/backends/WorkloadInfo.hpp>
+
+namespace armnn
+{
+class ITensorHandle;
+} // namespace armnn
+
+namespace
+{
+
+template <typename QueueDescriptor>
+void AddInputToWorkload(QueueDescriptor& descriptor,
+ armnn::WorkloadInfo& info,
+ const armnn::TensorInfo& tensorInfo,
+ armnn::ITensorHandle* tensorHandle)
+{
+ descriptor.m_Inputs.push_back(tensorHandle);
+ info.m_InputTensorInfos.push_back(tensorInfo);
+}
+
+template <typename QueueDescriptor>
+void AddOutputToWorkload(QueueDescriptor& descriptor,
+ armnn::WorkloadInfo& info,
+ const armnn::TensorInfo& tensorInfo,
+ armnn::ITensorHandle* tensorHandle)
+{
+ descriptor.m_Outputs.push_back(tensorHandle);
+ info.m_OutputTensorInfos.push_back(tensorInfo);
+}
+
+template <typename QueueDescriptor>
+void SetWorkloadInput(QueueDescriptor& descriptor,
+ armnn::WorkloadInfo& info,
+ unsigned int index,
+ const armnn::TensorInfo& tensorInfo,
+ armnn::ITensorHandle* tensorHandle)
+{
+ descriptor.m_Inputs[index] = tensorHandle;
+ info.m_InputTensorInfos[index] = tensorInfo;
+}
+
+template <typename QueueDescriptor>
+void SetWorkloadOutput(QueueDescriptor& descriptor,
+ armnn::WorkloadInfo& info,
+ unsigned int index,
+ const armnn::TensorInfo& tensorInfo,
+ armnn::ITensorHandle* tensorHandle)
+{
+ descriptor.m_Outputs[index] = tensorHandle;
+ info.m_OutputTensorInfos[index] = tensorInfo;
+}
+
+inline void ExecuteWorkload(armnn::IWorkload& workload,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ bool memoryManagementRequested = true)
+{
+ const bool manageMemory = memoryManager && memoryManagementRequested;
+
+ // Acquire working memory (if needed)
+ if (manageMemory)
+ {
+ memoryManager->Acquire();
+ }
+
+ // Perform PostAllocationConfiguration
+ workload.PostAllocationConfigure();
+
+ // Execute the workload
+ workload.Execute();
+
+ // Release working memory (if needed)
+ if (manageMemory)
+ {
+ memoryManager->Release();
+ }
+}
+
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+ if (!weightsType)
+ {
+ return weightsType;
+ }
+
+ switch(weightsType.value())
+ {
+ case armnn::DataType::BFloat16:
+ case armnn::DataType::Float16:
+ case armnn::DataType::Float32:
+ return weightsType;
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS8:
+ case armnn::DataType::QSymmS16:
+ return armnn::DataType::Signed32;
+ default:
+ ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+ }
+ return armnn::EmptyOptional();
+}
+
+} // anonymous namespace
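These helpers are typically used in layer tests roughly as follows. The function below is a sketch; the handles, factory, memory manager and the ReLu activation are assumptions for illustration, and the backend is assumed to provide an activation workload.

#include <armnnTestUtils/WorkloadTestUtils.hpp>

#include <armnn/backends/WorkloadFactory.hpp>

void RunActivationWorkload(armnn::IWorkloadFactory& workloadFactory,
                           const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
                           armnn::ITensorHandle* inputHandle,
                           armnn::ITensorHandle* outputHandle)
{
    armnn::TensorInfo tensorInfo({ 1, 1, 2, 2 }, armnn::DataType::Float32);

    armnn::ActivationQueueDescriptor descriptor;
    descriptor.m_Parameters.m_Function = armnn::ActivationFunction::ReLu;

    armnn::WorkloadInfo info;
    AddInputToWorkload(descriptor, info, tensorInfo, inputHandle);
    AddOutputToWorkload(descriptor, info, tensorInfo, outputHandle);

    // Assumes the backend overrides CreateActivation and returns a valid workload.
    std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateActivation(descriptor, info);

    inputHandle->Allocate();
    outputHandle->Allocate();

    // Acquires working memory if a memory manager is supplied, runs PostAllocationConfigure()
    // and Execute(), then releases the working memory again.
    ExecuteWorkload(*workload, memoryManager);
}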