aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-08-01 15:56:25 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-08-05 13:51:42 +0000
commitf674aa0fd2809126debdaaeb8067067790d86907 (patch)
treed86d0261c7a25149217918986043c76d0823ee44 /src/backends/backendsCommon
parent737d9ff58b348b11234b6c2363390607d576177d (diff)
downloadarmnn-f674aa0fd2809126debdaaeb8067067790d86907.tar.gz
IVGCVSW-3277 Mem export/import suppor for Tensors
* Rename MemoryStrategy to EdgeStrategy * Add MemImportLayer * Import memory rather than copy when possible Change-Id: I1d3a9414f2cbe517dc2aae9bbd4fdd92712b38ef Signed-off-by: Derek Lamberti <derek.lamberti@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/CMakeLists.txt5
-rw-r--r--src/backends/backendsCommon/ITensorHandle.hpp11
-rw-r--r--src/backends/backendsCommon/ITensorHandleFactory.hpp19
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp15
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp4
-rw-r--r--src/backends/backendsCommon/LayerSupportRules.hpp185
-rw-r--r--src/backends/backendsCommon/MemImportWorkload.cpp34
-rw-r--r--src/backends/backendsCommon/MemImportWorkload.hpp27
-rw-r--r--src/backends/backendsCommon/MemSyncWorkload.cpp33
-rw-r--r--src/backends/backendsCommon/MemSyncWorkload.hpp26
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp103
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp10
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp16
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/common.mk2
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
16 files changed, 484 insertions, 11 deletions
diff --git a/src/backends/backendsCommon/CMakeLists.txt b/src/backends/backendsCommon/CMakeLists.txt
index 653f3727ee..44131ea1b5 100644
--- a/src/backends/backendsCommon/CMakeLists.txt
+++ b/src/backends/backendsCommon/CMakeLists.txt
@@ -20,11 +20,16 @@ list(APPEND armnnBackendsCommon_sources
ITensorHandleFactory.hpp
LayerSupportBase.cpp
LayerSupportBase.hpp
+ LayerSupportRules.hpp
IMemoryManager.hpp
ITensorHandle.hpp
MakeWorkloadHelper.hpp
MemCopyWorkload.cpp
MemCopyWorkload.hpp
+ MemImportWorkload.cpp
+ MemImportWorkload.hpp
+ MemSyncWorkload.cpp
+ MemSyncWorkload.hpp
OptimizationViews.cpp
OptimizationViews.hpp
OutputHandler.cpp
diff --git a/src/backends/backendsCommon/ITensorHandle.hpp b/src/backends/backendsCommon/ITensorHandle.hpp
index 176b021d76..e1b80b874a 100644
--- a/src/backends/backendsCommon/ITensorHandle.hpp
+++ b/src/backends/backendsCommon/ITensorHandle.hpp
@@ -4,6 +4,8 @@
//
#pragma once
+#include <armnn/MemorySources.hpp>
+
namespace armnn
{
@@ -61,6 +63,15 @@ public:
// Testing support to be able to verify and set tensor data content
virtual void CopyOutTo(void* memory) const = 0;
virtual void CopyInFrom(const void* memory) = 0;
+
+ /// Get flags describing supported import sources.
+ virtual unsigned int GetImportFlags() const { return 0; }
+
+ /// Import externally allocated memory
+ /// \param memory base address of the memory being imported.
+ /// \param source source of the allocation for the memory being imported.
+ /// \return true on success or false on failure
+ virtual bool Import(void* memory, MemorySource source) { return false; };
};
}
diff --git a/src/backends/backendsCommon/ITensorHandleFactory.hpp b/src/backends/backendsCommon/ITensorHandleFactory.hpp
index 7685061eb3..89a2a7fa3b 100644
--- a/src/backends/backendsCommon/ITensorHandleFactory.hpp
+++ b/src/backends/backendsCommon/ITensorHandleFactory.hpp
@@ -5,8 +5,9 @@
#pragma once
-#include <armnn/Types.hpp>
#include <armnn/IRuntime.hpp>
+#include <armnn/MemorySources.hpp>
+#include <armnn/Types.hpp>
namespace armnn
{
@@ -20,7 +21,6 @@ public:
virtual ~ITensorHandleFactory() {}
-
virtual std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
TensorShape const& subTensorShape,
unsigned int const* subTensorOrigin) const = 0;
@@ -33,17 +33,16 @@ public:
virtual bool SupportsMapUnmap() const final { return true; }
- virtual bool SupportsExport() const final { return false; }
-
- virtual bool SupportsImport() const final { return false; }
+ virtual MemorySourceFlags GetExportFlags() const { return 0; }
+ virtual MemorySourceFlags GetImportFlags() const { return 0; }
};
-enum class MemoryStrategy
+enum class EdgeStrategy
{
- Undefined,
- DirectCompatibility, // Only allocate the tensorhandle using the assigned factory
- CopyToTarget, // Default + Insert MemCopy node before target
- ExportToTarget, // Default + Insert Import node
+ Undefined, /// No strategy has been defined. Used internally to verify integrity of optimizations.
+ DirectCompatibility, /// Destination backend can work directly with tensors on source backend.
+ ExportToTarget, /// Source backends tensor data can be exported to destination backend tensor without copy.
+ CopyToTarget /// Copy contents from source backend tensor to destination backend tensor.
};
} //namespace armnn
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index f202fedb4f..ee8dc5f7e9 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -7,6 +7,8 @@
#include <armnn/Exceptions.hpp>
+#include <boost/core/ignore_unused.hpp>
+
namespace
{
@@ -252,7 +254,18 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input,
const armnn::TensorInfo& output,
armnn::Optional<std::string &> reasonIfUnsupported) const
{
- return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+ boost::ignore_unused(input);
+ boost::ignore_unused(output);
+ return true;
+}
+
+bool LayerSupportBase::IsMemImportSupported(const armnn::TensorInfo& input,
+ const armnn::TensorInfo& output,
+ armnn::Optional<std::string &> reasonIfUnsupported) const
+{
+ boost::ignore_unused(input);
+ boost::ignore_unused(output);
+ return true;
}
bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index c860e34874..0d5a2af16e 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -157,6 +157,10 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsMemImportSupported(const TensorInfo& input,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsMergeSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
new file mode 100644
index 0000000000..db3f38ccbb
--- /dev/null
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -0,0 +1,185 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <boost/assert.hpp>
+#include <algorithm>
+
+namespace armnn
+{
+
+namespace
+{
+
+inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
+{
+ if (!weightsType)
+ {
+ return weightsType;
+ }
+
+ switch(weightsType.value())
+ {
+ case armnn::DataType::Float16:
+ case armnn::DataType::Float32:
+ return weightsType;
+ case armnn::DataType::QuantisedAsymm8:
+ return armnn::DataType::Signed32;
+ case armnn::DataType::QuantisedSymm16:
+ return armnn::DataType::Signed32;
+ default:
+ BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
+ }
+ return armnn::EmptyOptional();
+}
+
+} //namespace
+
+template<typename F>
+bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
+{
+ bool supported = rule();
+ if (!supported && reason)
+ {
+ reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
+ }
+ return supported;
+}
+
+struct Rule
+{
+ bool operator()() const
+ {
+ return m_Res;
+ }
+
+ bool m_Res = true;
+};
+
+template<typename T>
+bool AllTypesAreEqualImpl(T t)
+{
+ return true;
+}
+
+template<typename T, typename... Rest>
+bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
+{
+ static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
+
+ return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
+}
+
+struct TypesAreEqual : public Rule
+{
+ template<typename ... Ts>
+ TypesAreEqual(const Ts&... ts)
+ {
+ m_Res = AllTypesAreEqualImpl(ts...);
+ }
+};
+
+struct QuantizationParametersAreEqual : public Rule
+{
+ QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
+ {
+ m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
+ info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
+ }
+};
+
+struct TypeAnyOf : public Rule
+{
+ template<typename Container>
+ TypeAnyOf(const TensorInfo& info, const Container& c)
+ {
+ m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
+ {
+ return dt == info.GetDataType();
+ });
+ }
+};
+
+struct TypeIs : public Rule
+{
+ TypeIs(const TensorInfo& info, DataType dt)
+ {
+ m_Res = dt == info.GetDataType();
+ }
+};
+
+struct BiasAndWeightsTypesMatch : public Rule
+{
+ BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
+ {
+ m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
+ }
+};
+
+struct BiasAndWeightsTypesCompatible : public Rule
+{
+ template<typename Container>
+ BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
+ {
+ m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
+ {
+ return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
+ });
+ }
+};
+
+struct ShapesAreSameRank : public Rule
+{
+ ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
+ {
+ m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
+ }
+};
+
+struct ShapesAreSameTotalSize : public Rule
+{
+ ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
+ {
+ m_Res = info0.GetNumElements() == info1.GetNumElements();
+ }
+};
+
+struct ShapesAreBroadcastCompatible : public Rule
+{
+ unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
+ {
+ unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
+ unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
+ return sizeIn;
+ }
+
+ ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
+ {
+ const TensorShape& shape0 = in0.GetShape();
+ const TensorShape& shape1 = in1.GetShape();
+ const TensorShape& outShape = out.GetShape();
+
+ for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
+ {
+ unsigned int sizeOut = outShape[i];
+ unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
+ unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
+
+ m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
+ ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
+ }
+ }
+};
+
+struct TensorNumDimensionsAreCorrect : public Rule
+{
+ TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
+ {
+ m_Res = info.GetNumDimensions() == expectedNumDimensions;
+ }
+};
+
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/backendsCommon/MemImportWorkload.cpp b/src/backends/backendsCommon/MemImportWorkload.cpp
new file mode 100644
index 0000000000..ed00241bb6
--- /dev/null
+++ b/src/backends/backendsCommon/MemImportWorkload.cpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "MemImportWorkload.hpp"
+
+#include "CpuTensorHandle.hpp"
+
+#include <ResolveType.hpp>
+
+#include <boost/cast.hpp>
+
+#include <cstring>
+
+namespace armnn
+{
+
+ImportMemGenericWorkload::ImportMemGenericWorkload(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<MemImportQueueDescriptor>(descriptor, info)
+{
+ m_TensorHandlePairs = std::make_pair(descriptor.m_Inputs[0], descriptor.m_Outputs[0]);
+}
+
+void ImportMemGenericWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ImportMemGeneric_Execute");
+
+ m_TensorHandlePairs.second->Import(const_cast<void*>(m_TensorHandlePairs.first->Map(true)), MemorySource::Malloc);
+ m_TensorHandlePairs.first->Unmap();
+}
+
+} //namespace armnn
diff --git a/src/backends/backendsCommon/MemImportWorkload.hpp b/src/backends/backendsCommon/MemImportWorkload.hpp
new file mode 100644
index 0000000000..e16b99e9e0
--- /dev/null
+++ b/src/backends/backendsCommon/MemImportWorkload.hpp
@@ -0,0 +1,27 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "CpuTensorHandleFwd.hpp"
+#include "Workload.hpp"
+#include "WorkloadUtils.hpp"
+
+#include <utility>
+
+namespace armnn
+{
+
+class ImportMemGenericWorkload : public BaseWorkload<MemImportQueueDescriptor>
+{
+public:
+ ImportMemGenericWorkload(const MemImportQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void Execute() const override;
+
+private:
+ using TensorHandlePair = std::pair<const ITensorHandle*, ITensorHandle*>;
+ TensorHandlePair m_TensorHandlePairs;
+};
+
+} //namespace armnn
diff --git a/src/backends/backendsCommon/MemSyncWorkload.cpp b/src/backends/backendsCommon/MemSyncWorkload.cpp
new file mode 100644
index 0000000000..a1d309cefb
--- /dev/null
+++ b/src/backends/backendsCommon/MemSyncWorkload.cpp
@@ -0,0 +1,33 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "MemSyncWorkload.hpp"
+
+#include "CpuTensorHandle.hpp"
+
+#include <ResolveType.hpp>
+
+#include <boost/cast.hpp>
+
+#include <cstring>
+
+namespace armnn
+{
+
+SyncMemGenericWorkload::SyncMemGenericWorkload(const MemSyncQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<MemSyncQueueDescriptor>(descriptor, info)
+{
+ m_TensorHandle = descriptor.m_Inputs[0];
+}
+
+void SyncMemGenericWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "SyncMemGeneric_Execute");
+ m_TensorHandle->Map(true);
+ m_TensorHandle->Unmap();
+}
+
+} //namespace armnn
diff --git a/src/backends/backendsCommon/MemSyncWorkload.hpp b/src/backends/backendsCommon/MemSyncWorkload.hpp
new file mode 100644
index 0000000000..3a167d2a00
--- /dev/null
+++ b/src/backends/backendsCommon/MemSyncWorkload.hpp
@@ -0,0 +1,26 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "CpuTensorHandleFwd.hpp"
+#include "Workload.hpp"
+#include "WorkloadUtils.hpp"
+
+#include <utility>
+
+namespace armnn
+{
+
+class SyncMemGenericWorkload : public BaseWorkload<MemSyncQueueDescriptor>
+{
+public:
+ SyncMemGenericWorkload(const MemSyncQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void Execute() const override;
+
+private:
+ ITensorHandle* m_TensorHandle;
+};
+
+} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a4d35827fa..1c607da707 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -351,6 +351,109 @@ void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
+//---------------------------------------------------------------
+void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);
+
+ if (workloadInfo.m_InputTensorInfos.size() != 1)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of input infos (%1%) is not 1.")
+ % workloadInfo.m_InputTensorInfos.size()));
+
+ }
+
+ if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of input infos (%1%) does not match the number of output infos (%2%)")
+ % workloadInfo.m_InputTensorInfos.size() % workloadInfo.m_OutputTensorInfos.size()));
+ }
+
+ for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+ {
+ if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
+ workloadInfo.m_OutputTensorInfos[i].GetNumElements())
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of elements for tensor input and output %1% does not match")
+ % i ));
+ }
+ }
+
+ if (m_Inputs.size() != 1)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of inputs (%1%) is not 1.")
+ % m_Inputs.size()));
+ }
+
+ if (m_Inputs.size() != m_Outputs.size())
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of inputs (%1%) does not match the number of outputs (%2%)")
+ % m_Inputs.size() % m_Outputs.size()));
+ }
+
+ for (unsigned int i = 0; i < m_Inputs.size(); ++i)
+ {
+ if (!m_Inputs[i])
+ {
+ throw InvalidArgumentException(boost::str(boost::format("Invalid null input %1%") % i));
+ }
+
+ if (!m_Outputs[i])
+ {
+ throw InvalidArgumentException(boost::str(boost::format("Invalid null output %1%") % i));
+ }
+ }
+}
+
+//---------------------------------------------------------------
+void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
+
+ if (workloadInfo.m_InputTensorInfos.size() != 1)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of input infos (%1%) is not 1.")
+ % workloadInfo.m_InputTensorInfos.size()));
+
+ }
+
+ if (workloadInfo.m_OutputTensorInfos.size() != 0)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of output infos (%1%) is not 0.")
+ % workloadInfo.m_InputTensorInfos.size()));
+
+ }
+
+ if (m_Inputs.size() != 1)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of inputs (%1%) is not 1.")
+ % m_Inputs.size()));
+ }
+
+ if (m_Outputs.size() != 0)
+ {
+ throw InvalidArgumentException(boost::str(
+ boost::format("Number of outputs (%1%) is not 0.")
+ % m_Inputs.size() % m_Outputs.size()));
+ }
+
+ if (!m_Inputs[0])
+ {
+ throw InvalidArgumentException(boost::str(boost::format("Invalid null input 0")));
+ }
+}
+
+//---------------------------------------------------------------
void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"ActivationQueueDescriptor"};
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index d790dafd58..c055beb88d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -63,6 +63,16 @@ struct MemCopyQueueDescriptor : QueueDescriptor
using InputQueueDescriptor = MemCopyQueueDescriptor;
using OutputQueueDescriptor = MemCopyQueueDescriptor;
+struct MemImportQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
+struct MemSyncQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
// Softmax layer workload data.
struct SoftmaxQueueDescriptor : QueueDescriptorWithParameters<SoftmaxDescriptor>
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 1f616f0b18..ffef5b4eb7 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -515,6 +515,16 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::MemImport:
+ {
+ const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
+ OverrideDataType(output, dataType),
+ reason);
+ break;
+ }
case LayerType::Merge:
{
const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
@@ -1092,6 +1102,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index bd7f1c627b..a9c6049c37 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -124,6 +124,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk
index 69bde81b0a..eee1dae0ff 100644
--- a/src/backends/backendsCommon/common.mk
+++ b/src/backends/backendsCommon/common.mk
@@ -14,6 +14,8 @@ COMMON_SOURCES := \
ITensorHandleFactory.cpp \
LayerSupportBase.cpp \
MemCopyWorkload.cpp \
+ MemImportWorkload.cpp \
+ MemSyncWorkload.cpp \
OptimizationViews.cpp \
OutputHandler.cpp \
TensorHandleFactoryRegistry.cpp \
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 451c585adc..1f43c989d6 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -405,6 +405,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Convolution2d)
DECLARE_LAYER_POLICY_1_PARAM(MemCopy)
+DECLARE_LAYER_POLICY_1_PARAM(MemImport)
+
DECLARE_LAYER_POLICY_1_PARAM(Debug)
DECLARE_LAYER_POLICY_2_PARAM(DepthwiseConvolution2d)