diff options
author | Derek Lamberti <derek.lamberti@arm.com> | 2019-08-01 15:56:25 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-08-05 13:51:42 +0000 |
commit | f674aa0fd2809126debdaaeb8067067790d86907 (patch) | |
tree | d86d0261c7a25149217918986043c76d0823ee44 /src/backends/backendsCommon | |
parent | 737d9ff58b348b11234b6c2363390607d576177d (diff) | |
download | armnn-f674aa0fd2809126debdaaeb8067067790d86907.tar.gz |
IVGCVSW-3277 Mem export/import suppor for Tensors
* Rename MemoryStrategy to EdgeStrategy
* Add MemImportLayer
* Import memory rather than copy when possible
Change-Id: I1d3a9414f2cbe517dc2aae9bbd4fdd92712b38ef
Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/CMakeLists.txt | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/ITensorHandle.hpp | 11 | ||||
-rw-r--r-- | src/backends/backendsCommon/ITensorHandleFactory.hpp | 19 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 15 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 4 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportRules.hpp | 185 | ||||
-rw-r--r-- | src/backends/backendsCommon/MemImportWorkload.cpp | 34 | ||||
-rw-r--r-- | src/backends/backendsCommon/MemImportWorkload.hpp | 27 | ||||
-rw-r--r-- | src/backends/backendsCommon/MemSyncWorkload.cpp | 33 | ||||
-rw-r--r-- | src/backends/backendsCommon/MemSyncWorkload.hpp | 26 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 103 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 10 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 16 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/common.mk | 2 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp | 2 |
16 files changed, 484 insertions, 11 deletions
diff --git a/src/backends/backendsCommon/CMakeLists.txt b/src/backends/backendsCommon/CMakeLists.txt index 653f3727ee..44131ea1b5 100644 --- a/src/backends/backendsCommon/CMakeLists.txt +++ b/src/backends/backendsCommon/CMakeLists.txt @@ -20,11 +20,16 @@ list(APPEND armnnBackendsCommon_sources ITensorHandleFactory.hpp LayerSupportBase.cpp LayerSupportBase.hpp + LayerSupportRules.hpp IMemoryManager.hpp ITensorHandle.hpp MakeWorkloadHelper.hpp MemCopyWorkload.cpp MemCopyWorkload.hpp + MemImportWorkload.cpp + MemImportWorkload.hpp + MemSyncWorkload.cpp + MemSyncWorkload.hpp OptimizationViews.cpp OptimizationViews.hpp OutputHandler.cpp diff --git a/src/backends/backendsCommon/ITensorHandle.hpp b/src/backends/backendsCommon/ITensorHandle.hpp index 176b021d76..e1b80b874a 100644 --- a/src/backends/backendsCommon/ITensorHandle.hpp +++ b/src/backends/backendsCommon/ITensorHandle.hpp @@ -4,6 +4,8 @@ // #pragma once +#include <armnn/MemorySources.hpp> + namespace armnn { @@ -61,6 +63,15 @@ public: // Testing support to be able to verify and set tensor data content virtual void CopyOutTo(void* memory) const = 0; virtual void CopyInFrom(const void* memory) = 0; + + /// Get flags describing supported import sources. + virtual unsigned int GetImportFlags() const { return 0; } + + /// Import externally allocated memory + /// \param memory base address of the memory being imported. + /// \param source source of the allocation for the memory being imported. + /// \return true on success or false on failure + virtual bool Import(void* memory, MemorySource source) { return false; }; }; } diff --git a/src/backends/backendsCommon/ITensorHandleFactory.hpp b/src/backends/backendsCommon/ITensorHandleFactory.hpp index 7685061eb3..89a2a7fa3b 100644 --- a/src/backends/backendsCommon/ITensorHandleFactory.hpp +++ b/src/backends/backendsCommon/ITensorHandleFactory.hpp @@ -5,8 +5,9 @@ #pragma once -#include <armnn/Types.hpp> #include <armnn/IRuntime.hpp> +#include <armnn/MemorySources.hpp> +#include <armnn/Types.hpp> namespace armnn { @@ -20,7 +21,6 @@ public: virtual ~ITensorHandleFactory() {} - virtual std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, TensorShape const& subTensorShape, unsigned int const* subTensorOrigin) const = 0; @@ -33,17 +33,16 @@ public: virtual bool SupportsMapUnmap() const final { return true; } - virtual bool SupportsExport() const final { return false; } - - virtual bool SupportsImport() const final { return false; } + virtual MemorySourceFlags GetExportFlags() const { return 0; } + virtual MemorySourceFlags GetImportFlags() const { return 0; } }; -enum class MemoryStrategy +enum class EdgeStrategy { - Undefined, - DirectCompatibility, // Only allocate the tensorhandle using the assigned factory - CopyToTarget, // Default + Insert MemCopy node before target - ExportToTarget, // Default + Insert Import node + Undefined, /// No strategy has been defined. Used internally to verify integrity of optimizations. + DirectCompatibility, /// Destination backend can work directly with tensors on source backend. + ExportToTarget, /// Source backends tensor data can be exported to destination backend tensor without copy. + CopyToTarget /// Copy contents from source backend tensor to destination backend tensor. }; } //namespace armnn diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index f202fedb4f..ee8dc5f7e9 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -7,6 +7,8 @@ #include <armnn/Exceptions.hpp> +#include <boost/core/ignore_unused.hpp> + namespace { @@ -252,7 +254,18 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input, const armnn::TensorInfo& output, armnn::Optional<std::string &> reasonIfUnsupported) const { - return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); + boost::ignore_unused(input); + boost::ignore_unused(output); + return true; +} + +bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo& input, + const armnn::TensorInfo& output, + armnn::Optional<std::string &> reasonIfUnsupported) const +{ + boost::ignore_unused(input); + boost::ignore_unused(output); + return true; } bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index c860e34874..0d5a2af16e 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -157,6 +157,10 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMemImportSupported(const TensorInfo& input, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMergeSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp new file mode 100644 index 0000000000..db3f38ccbb --- /dev/null +++ b/src/backends/backendsCommon/LayerSupportRules.hpp @@ -0,0 +1,185 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <boost/assert.hpp> +#include <algorithm> + +namespace armnn +{ + +namespace +{ + +inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType) +{ + if (!weightsType) + { + return weightsType; + } + + switch(weightsType.value()) + { + case armnn::DataType::Float16: + case armnn::DataType::Float32: + return weightsType; + case armnn::DataType::QuantisedAsymm8: + return armnn::DataType::Signed32; + case armnn::DataType::QuantisedSymm16: + return armnn::DataType::Signed32; + default: + BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type."); + } + return armnn::EmptyOptional(); +} + +} //namespace + +template<typename F> +bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason) +{ + bool supported = rule(); + if (!supported && reason) + { + reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line + } + return supported; +} + +struct Rule +{ + bool operator()() const + { + return m_Res; + } + + bool m_Res = true; +}; + +template<typename T> +bool AllTypesAreEqualImpl(T t) +{ + return true; +} + +template<typename T, typename... Rest> +bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest) +{ + static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo"); + + return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...); +} + +struct TypesAreEqual : public Rule +{ + template<typename ... Ts> + TypesAreEqual(const Ts&... ts) + { + m_Res = AllTypesAreEqualImpl(ts...); + } +}; + +struct QuantizationParametersAreEqual : public Rule +{ + QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() && + info0.GetQuantizationOffset() == info1.GetQuantizationOffset(); + } +}; + +struct TypeAnyOf : public Rule +{ + template<typename Container> + TypeAnyOf(const TensorInfo& info, const Container& c) + { + m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) + { + return dt == info.GetDataType(); + }); + } +}; + +struct TypeIs : public Rule +{ + TypeIs(const TensorInfo& info, DataType dt) + { + m_Res = dt == info.GetDataType(); + } +}; + +struct BiasAndWeightsTypesMatch : public Rule +{ + BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights) + { + m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value(); + } +}; + +struct BiasAndWeightsTypesCompatible : public Rule +{ + template<typename Container> + BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c) + { + m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt) + { + return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value(); + }); + } +}; + +struct ShapesAreSameRank : public Rule +{ + ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions(); + } +}; + +struct ShapesAreSameTotalSize : public Rule +{ + ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1) + { + m_Res = info0.GetNumElements() == info1.GetNumElements(); + } +}; + +struct ShapesAreBroadcastCompatible : public Rule +{ + unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx) + { + unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions(); + unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset]; + return sizeIn; + } + + ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out) + { + const TensorShape& shape0 = in0.GetShape(); + const TensorShape& shape1 = in1.GetShape(); + const TensorShape& outShape = out.GetShape(); + + for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++) + { + unsigned int sizeOut = outShape[i]; + unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i); + unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i); + + m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) && + ((sizeIn1 == sizeOut) || (sizeIn1 == 1)); + } + } +}; + +struct TensorNumDimensionsAreCorrect : public Rule +{ + TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions) + { + m_Res = info.GetNumDimensions() == expectedNumDimensions; + } +}; + +} //namespace armnn
\ No newline at end of file diff --git a/src/backends/backendsCommon/MemImportWorkload.cpp b/src/backends/backendsCommon/MemImportWorkload.cpp new file mode 100644 index 0000000000..ed00241bb6 --- /dev/null +++ b/src/backends/backendsCommon/MemImportWorkload.cpp @@ -0,0 +1,34 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "MemImportWorkload.hpp" + +#include "CpuTensorHandle.hpp" + +#include <ResolveType.hpp> + +#include <boost/cast.hpp> + +#include <cstring> + +namespace armnn +{ + +ImportMemGenericWorkload::ImportMemGenericWorkload(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<MemImportQueueDescriptor>(descriptor, info) +{ + m_TensorHandlePairs = std::make_pair(descriptor.m_Inputs[0], descriptor.m_Outputs[0]); +} + +void ImportMemGenericWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ImportMemGeneric_Execute"); + + m_TensorHandlePairs.second->Import(const_cast<void*>(m_TensorHandlePairs.first->Map(true)), MemorySource::Malloc); + m_TensorHandlePairs.first->Unmap(); +} + +} //namespace armnn diff --git a/src/backends/backendsCommon/MemImportWorkload.hpp b/src/backends/backendsCommon/MemImportWorkload.hpp new file mode 100644 index 0000000000..e16b99e9e0 --- /dev/null +++ b/src/backends/backendsCommon/MemImportWorkload.hpp @@ -0,0 +1,27 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "CpuTensorHandleFwd.hpp" +#include "Workload.hpp" +#include "WorkloadUtils.hpp" + +#include <utility> + +namespace armnn +{ + +class ImportMemGenericWorkload : public BaseWorkload<MemImportQueueDescriptor> +{ +public: + ImportMemGenericWorkload(const MemImportQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + using TensorHandlePair = std::pair<const ITensorHandle*, ITensorHandle*>; + TensorHandlePair m_TensorHandlePairs; +}; + +} //namespace armnn diff --git a/src/backends/backendsCommon/MemSyncWorkload.cpp b/src/backends/backendsCommon/MemSyncWorkload.cpp new file mode 100644 index 0000000000..a1d309cefb --- /dev/null +++ b/src/backends/backendsCommon/MemSyncWorkload.cpp @@ -0,0 +1,33 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "MemSyncWorkload.hpp" + +#include "CpuTensorHandle.hpp" + +#include <ResolveType.hpp> + +#include <boost/cast.hpp> + +#include <cstring> + +namespace armnn +{ + +SyncMemGenericWorkload::SyncMemGenericWorkload(const MemSyncQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<MemSyncQueueDescriptor>(descriptor, info) +{ + m_TensorHandle = descriptor.m_Inputs[0]; +} + +void SyncMemGenericWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "SyncMemGeneric_Execute"); + m_TensorHandle->Map(true); + m_TensorHandle->Unmap(); +} + +} //namespace armnn diff --git a/src/backends/backendsCommon/MemSyncWorkload.hpp b/src/backends/backendsCommon/MemSyncWorkload.hpp new file mode 100644 index 0000000000..3a167d2a00 --- /dev/null +++ b/src/backends/backendsCommon/MemSyncWorkload.hpp @@ -0,0 +1,26 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "CpuTensorHandleFwd.hpp" +#include "Workload.hpp" +#include "WorkloadUtils.hpp" + +#include <utility> + +namespace armnn +{ + +class SyncMemGenericWorkload : public BaseWorkload<MemSyncQueueDescriptor> +{ +public: + SyncMemGenericWorkload(const MemSyncQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + ITensorHandle* m_TensorHandle; +}; + +} //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a4d35827fa..1c607da707 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -351,6 +351,109 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +//--------------------------------------------------------------- +void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)") + % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size())); + } + + for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i) + { + if (workloadInfo.m_InputTensorInfos[i].GetNumElements() != + workloadInfo.m_OutputTensorInfos[i].GetNumElements()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of elements for tensor input and output %1% does not match") + % i )); + } + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Inputs.size() != m_Outputs.size()) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)") + % m_Inputs.size() % m_Outputs.size())); + } + + for (unsigned int i = 0; i < m_Inputs.size(); ++i) + { + if (!m_Inputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i)); + } + + if (!m_Outputs[i]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i)); + } + } +} + +//--------------------------------------------------------------- +void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1); + ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1); + + if (workloadInfo.m_InputTensorInfos.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of input infos (%1%) is not 1.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (workloadInfo.m_OutputTensorInfos.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of output infos (%1%) is not 0.") + % workloadInfo.m_InputTensorInfos.size())); + + } + + if (m_Inputs.size() != 1) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of inputs (%1%) is not 1.") + % m_Inputs.size())); + } + + if (m_Outputs.size() != 0) + { + throw InvalidArgumentException(boost::str( + boost::format("Number of outputs (%1%) is not 0.") + % m_Inputs.size() % m_Outputs.size())); + } + + if (!m_Inputs[0]) + { + throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0"))); + } +} + +//--------------------------------------------------------------- void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"ActivationQueueDescriptor"}; diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index d790dafd58..c055beb88d 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -63,6 +63,16 @@ struct MemCopyQueueDescriptor : QueueDescriptor using InputQueueDescriptor = MemCopyQueueDescriptor; using OutputQueueDescriptor = MemCopyQueueDescriptor; +struct MemImportQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + +struct MemSyncQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + // Softmax layer workload data. struct SoftmaxQueueDescriptor : QueueDescriptorWithParameters<SoftmaxDescriptor> { diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 1f616f0b18..ffef5b4eb7 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -515,6 +515,16 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::MemImport: + { + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType), + OverrideDataType(output, dataType), + reason); + break; + } case LayerType::Merge: { const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); @@ -1092,6 +1102,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index bd7f1c627b..a9c6049c37 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -124,6 +124,9 @@ public: virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk index 69bde81b0a..eee1dae0ff 100644 --- a/src/backends/backendsCommon/common.mk +++ b/src/backends/backendsCommon/common.mk @@ -14,6 +14,8 @@ COMMON_SOURCES := \ ITensorHandleFactory.cpp \ LayerSupportBase.cpp \ MemCopyWorkload.cpp \ + MemImportWorkload.cpp \ + MemSyncWorkload.cpp \ OptimizationViews.cpp \ OutputHandler.cpp \ TensorHandleFactoryRegistry.cpp \ diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 451c585adc..1f43c989d6 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -405,6 +405,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Convolution2d) DECLARE_LAYER_POLICY_1_PARAM(MemCopy) +DECLARE_LAYER_POLICY_1_PARAM(MemImport) + DECLARE_LAYER_POLICY_1_PARAM(Debug) DECLARE_LAYER_POLICY_2_PARAM(DepthwiseConvolution2d) |