aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancis Murtagh <francis.murtagh@arm.com>2018-09-24 15:01:18 +0100
committerMatthew Bentham <matthew.bentham@arm.com>2018-10-10 16:16:57 +0100
commit351d13d0b5fa698b72130012b2f069d30b911cb3 (patch)
treeb9417a78336e3e1b4c7d8775a2b3fd5bce0d624a
parent2ca4696639c9d2361b24adbd9a33225d18527fde (diff)
downloadarmnn-351d13d0b5fa698b72130012b2f069d30b911cb3.tar.gz
IVGCVSW-1888 Plumb data layout parameter for Convolution2D
* Added the DataLayout parameter to the Convolution2dDescriptor * Added the DataLayout parameter to the Convolution2dQueueDescriptor * Set the DataLayout on the Descriptor in CreateWorkload() * Added overloaded factory methods for CreateTensorHandle() * Updated BuildArmComputeTensorInfo() to take a DataLayout parameter. * Updated handles to take a DataLayout parameter * Updated (Cl/Neon)Convolution2dWorkloadValidate * Updated (Cl/Neon)Convolution2dFloatWorkload * Updated (Cl/Neon)Convolution2dUint8Workload Change-Id: I8410668b3d727ca587bee66755cc4c4c78422f1f
-rw-r--r--include/armnn/Descriptors.hpp2
-rw-r--r--src/armnn/layers/Convolution2dLayer.cpp3
-rw-r--r--src/backends/ArmComputeTensorUtils.cpp28
-rw-r--r--src/backends/ArmComputeTensorUtils.hpp17
-rw-r--r--src/backends/ClTensorHandle.hpp5
-rw-r--r--src/backends/ClWorkloadFactory.cpp15
-rw-r--r--src/backends/ClWorkloadFactory.hpp3
-rw-r--r--src/backends/ClWorkloads/ClConvolution2dBaseWorkload.cpp8
-rw-r--r--src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp4
-rw-r--r--src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp4
-rw-r--r--src/backends/NeonTensorHandle.hpp5
-rw-r--r--src/backends/NeonWorkloadFactory.cpp15
-rw-r--r--src/backends/NeonWorkloadFactory.hpp3
-rw-r--r--src/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp12
-rw-r--r--src/backends/OutputHandler.cpp5
-rw-r--r--src/backends/OutputHandler.hpp5
-rw-r--r--src/backends/RefWorkloadFactory.cpp6
-rw-r--r--src/backends/RefWorkloadFactory.hpp3
-rw-r--r--src/backends/WorkloadData.hpp2
-rw-r--r--src/backends/WorkloadFactory.hpp3
20 files changed, 134 insertions, 14 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 8940e0b003..dfd532f633 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -216,6 +216,7 @@ struct Convolution2dDescriptor
, m_StrideX(0)
, m_StrideY(0)
, m_BiasEnabled(false)
+ , m_DataLayout(DataLayout::NCHW)
{};
uint32_t m_PadLeft;
@@ -225,6 +226,7 @@ struct Convolution2dDescriptor
uint32_t m_StrideX;
uint32_t m_StrideY;
bool m_BiasEnabled;
+ DataLayout m_DataLayout;
};
struct DepthwiseConvolution2dDescriptor
diff --git a/src/armnn/layers/Convolution2dLayer.cpp b/src/armnn/layers/Convolution2dLayer.cpp
index 71f54b88f8..07d6d7eee4 100644
--- a/src/armnn/layers/Convolution2dLayer.cpp
+++ b/src/armnn/layers/Convolution2dLayer.cpp
@@ -26,6 +26,9 @@ std::unique_ptr<IWorkload> Convolution2dLayer::CreateWorkload(const Graph& graph
Convolution2dQueueDescriptor descriptor;
descriptor.m_Weight = m_Weight.get();
+
+ descriptor.m_DataLayout = GetParameters().m_DataLayout;
+
if (m_Param.m_BiasEnabled)
{
BOOST_ASSERT_MSG(m_Bias != nullptr, "Convolution2dLayer: Bias data should not be null.");
diff --git a/src/backends/ArmComputeTensorUtils.cpp b/src/backends/ArmComputeTensorUtils.cpp
index ba9fb40cfc..e65c4ad35f 100644
--- a/src/backends/ArmComputeTensorUtils.cpp
+++ b/src/backends/ArmComputeTensorUtils.cpp
@@ -5,6 +5,7 @@
#include "ArmComputeTensorUtils.hpp"
#include "ArmComputeUtils.hpp"
+#include "armnn/Exceptions.hpp"
#include <armnn/Descriptors.hpp>
namespace armnn
@@ -66,6 +67,33 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso
return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
}
+arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout)
+{
+ switch(dataLayout)
+ {
+ case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC;
+
+ case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW;
+
+ default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" +
+ std::to_string(static_cast<int>(dataLayout)) + "]");
+ }
+}
+
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout)
+{
+ const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape());
+ const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType());
+ const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(),
+ tensorInfo.GetQuantizationOffset());
+
+ arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo);
+ clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout));
+
+ return clTensorInfo;
+}
+
arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor)
{
using arm_compute::PoolingType;
diff --git a/src/backends/ArmComputeTensorUtils.hpp b/src/backends/ArmComputeTensorUtils.hpp
index 572e310ecf..18f41ee173 100644
--- a/src/backends/ArmComputeTensorUtils.hpp
+++ b/src/backends/ArmComputeTensorUtils.hpp
@@ -30,6 +30,16 @@ arm_compute::TensorShape BuildArmComputeTensorShape(const armnn::TensorShape& te
/// armnn::ITensorInfo.
arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo);
/// Utility function used to convert an armnn::DataLayout to the corresponding
/// arm_compute::DataLayout.
+arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout);
+
/// Utility function used to setup an arm_compute::ITensorInfo object whose dimensions are based on the given
/// armnn::ITensorInfo, with its data layout set from the given armnn::DataLayout.
+arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo,
+ armnn::DataLayout dataLayout);
+
/// Utility function used to setup an arm_compute::PoolingLayerInfo object from an armnn::Pooling2dDescriptor.
arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor);
@@ -59,6 +69,13 @@ void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo)
tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo));
}
+/// Sets up the given ArmCompute tensor's dimensions based on the given ArmNN tensor.
+template <typename Tensor>
+void BuildArmComputeTensor(Tensor& tensor, const armnn::TensorInfo& tensorInfo, DataLayout dataLayout)
+{
+ tensor.allocator()->init(BuildArmComputeTensorInfo(tensorInfo, dataLayout));
+}
+
template <typename Tensor>
void InitialiseArmComputeTensorEmpty(Tensor& tensor)
{
diff --git a/src/backends/ClTensorHandle.hpp b/src/backends/ClTensorHandle.hpp
index 9c78192284..e3d7b5b491 100644
--- a/src/backends/ClTensorHandle.hpp
+++ b/src/backends/ClTensorHandle.hpp
@@ -37,6 +37,11 @@ public:
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
}
+ ClTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
+ {
+ armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
+ }
+
arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
virtual void Allocate() override {armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);}
diff --git a/src/backends/ClWorkloadFactory.cpp b/src/backends/ClWorkloadFactory.cpp
index 591fb85dbb..5f395a2f6f 100644
--- a/src/backends/ClWorkloadFactory.cpp
+++ b/src/backends/ClWorkloadFactory.cpp
@@ -55,6 +55,15 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const Tenso
return tensorHandle;
}
+std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const
+{
+ std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
+ tensorHandle->SetMemoryGroup(m_MemoryManager.GetInterLayerMemoryGroup());
+
+ return tensorHandle;
+}
+
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
TensorShape const& subTensorShape,
unsigned int const* subTensorOrigin) const
@@ -290,6 +299,12 @@ std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const Tenso
return nullptr;
}
+std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const
+{
+ return nullptr;
+}
+
std::unique_ptr<ITensorHandle> ClWorkloadFactory::CreateSubTensorHandle(ITensorHandle& parent,
TensorShape const& subTensorShape,
unsigned int const* subTensorOrigin) const
diff --git a/src/backends/ClWorkloadFactory.hpp b/src/backends/ClWorkloadFactory.hpp
index 892d564fbb..d0bf4160f6 100644
--- a/src/backends/ClWorkloadFactory.hpp
+++ b/src/backends/ClWorkloadFactory.hpp
@@ -33,6 +33,9 @@ public:
virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const override;
+
virtual std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/ClWorkloads/ClConvolution2dBaseWorkload.cpp b/src/backends/ClWorkloads/ClConvolution2dBaseWorkload.cpp
index 228f17d54e..110a2dab3a 100644
--- a/src/backends/ClWorkloads/ClConvolution2dBaseWorkload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dBaseWorkload.cpp
@@ -21,9 +21,9 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
const TensorInfo& weights,
const boost::optional<TensorInfo>& biases)
{
- const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
- const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights);
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
arm_compute::TensorInfo aclBiasesInfo;
arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
@@ -32,7 +32,7 @@ arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input,
{
BOOST_ASSERT(biases.is_initialized());
- aclBiasesInfo = BuildArmComputeTensorInfo(biases.get());
+ aclBiasesInfo = BuildArmComputeTensorInfo(biases.get(), descriptor.m_DataLayout);
optionalAclBiasesInfo = &aclBiasesInfo;
}
diff --git a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
index f0b9a46d60..3da6fa7d8f 100644
--- a/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dFloatWorkload.cpp
@@ -25,7 +25,7 @@ ClConvolution2dFloatWorkload::ClConvolution2dFloatWorkload(const Convolution2dQu
const TensorInfo& weightInfo = m_Data.m_Weight->GetTensorInfo();
m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
- BuildArmComputeTensor(*m_KernelTensor, weightInfo);
+ BuildArmComputeTensor(*m_KernelTensor, weightInfo, descriptor.m_DataLayout);
arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
m_Data.m_Parameters.m_StrideY,
@@ -38,7 +38,7 @@ ClConvolution2dFloatWorkload::ClConvolution2dFloatWorkload(const Convolution2dQu
if (m_Data.m_Parameters.m_BiasEnabled)
{
m_BiasTensor = std::make_unique<arm_compute::CLTensor>();
- BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
+ BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo(), descriptor.m_DataLayout);
}
m_Data.ValidateInputsOutputs("ClConvolution2dFloat32Workload", 1, 1);
diff --git a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
index c9f5eaa31d..3949a74c96 100644
--- a/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
+++ b/src/backends/ClWorkloads/ClConvolution2dUint8Workload.cpp
@@ -24,7 +24,7 @@ ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQu
const TensorInfo& weightInfo = m_Data.m_Weight->GetTensorInfo();
m_KernelTensor = std::make_unique<arm_compute::CLTensor>();
- BuildArmComputeTensor(*m_KernelTensor, weightInfo);
+ BuildArmComputeTensor(*m_KernelTensor, weightInfo, descriptor.m_DataLayout);
arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
m_Data.m_Parameters.m_StrideY,
@@ -37,7 +37,7 @@ ClConvolution2dUint8Workload::ClConvolution2dUint8Workload(const Convolution2dQu
if (m_Data.m_Parameters.m_BiasEnabled)
{
m_BiasTensor = std::make_unique<arm_compute::CLTensor>();
- BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
+ BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo(), descriptor.m_DataLayout);
}
m_Data.ValidateInputsOutputs("ClConvolution2dUint8Workload", 1, 1);
diff --git a/src/backends/NeonTensorHandle.hpp b/src/backends/NeonTensorHandle.hpp
index e385c83967..77f3cc1184 100644
--- a/src/backends/NeonTensorHandle.hpp
+++ b/src/backends/NeonTensorHandle.hpp
@@ -36,6 +36,11 @@ public:
armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
}
+ NeonTensorHandle(const TensorInfo& tensorInfo, DataLayout dataLayout)
+ {
+ armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
+ }
+
arm_compute::ITensor& GetTensor() override { return m_Tensor; }
arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
diff --git a/src/backends/NeonWorkloadFactory.cpp b/src/backends/NeonWorkloadFactory.cpp
index 80ce0b918e..c989121eac 100644
--- a/src/backends/NeonWorkloadFactory.cpp
+++ b/src/backends/NeonWorkloadFactory.cpp
@@ -67,6 +67,15 @@ std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateTensorHandle(const Ten
return tensorHandle;
}
+std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const
+{
+ auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
+ tensorHandle->SetMemoryGroup(m_MemoryManager.GetInterLayerMemoryGroup());
+
+ return tensorHandle;
+}
+
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -289,6 +298,12 @@ std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateTensorHandle(const Ten
return nullptr;
}
+std::unique_ptr<ITensorHandle> NeonWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const
+{
+ return nullptr;
+}
+
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/NeonWorkloadFactory.hpp b/src/backends/NeonWorkloadFactory.hpp
index a981855314..45d1c2c8c0 100644
--- a/src/backends/NeonWorkloadFactory.hpp
+++ b/src/backends/NeonWorkloadFactory.hpp
@@ -33,6 +33,9 @@ public:
virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const override;
+
virtual std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp b/src/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
index 0e9894ce78..912e2d5b69 100644
--- a/src/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
+++ b/src/backends/NeonWorkloads/NeonConvolution2dBaseWorkload.cpp
@@ -23,9 +23,9 @@ arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
const TensorInfo& weights,
const boost::optional<TensorInfo>& biases)
{
- const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input);
- const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output);
- const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights);
+ const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output, descriptor.m_DataLayout);
+ const arm_compute::TensorInfo aclWeightsInfo = BuildArmComputeTensorInfo(weights, descriptor.m_DataLayout);
arm_compute::TensorInfo aclBiasesInfo;
arm_compute::TensorInfo *optionalAclBiasesInfo = nullptr;
@@ -34,7 +34,7 @@ arm_compute::Status NeonConvolution2dWorkloadValidate(const TensorInfo& input,
{
BOOST_ASSERT(biases.is_initialized());
- aclBiasesInfo = BuildArmComputeTensorInfo(biases.get());
+ aclBiasesInfo = BuildArmComputeTensorInfo(biases.get(), descriptor.m_DataLayout);
optionalAclBiasesInfo = &aclBiasesInfo;
}
@@ -63,12 +63,12 @@ NeonConvolution2dBaseWorkload<dataTypes...>::NeonConvolution2dBaseWorkload(
arm_compute::ITensor& output = boost::polymorphic_downcast<INeonTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
m_KernelTensor = std::make_unique<arm_compute::Tensor>();
- BuildArmComputeTensor(*m_KernelTensor, m_Data.m_Weight->GetTensorInfo());
+ BuildArmComputeTensor(*m_KernelTensor, m_Data.m_Weight->GetTensorInfo(), descriptor.m_DataLayout);
if (m_Data.m_Parameters.m_BiasEnabled)
{
m_BiasTensor = std::make_unique<arm_compute::Tensor>();
- BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo());
+ BuildArmComputeTensor(*m_BiasTensor, m_Data.m_Bias->GetTensorInfo(), descriptor.m_DataLayout);
}
arm_compute::PadStrideInfo padStrideInfo(m_Data.m_Parameters.m_StrideX,
diff --git a/src/backends/OutputHandler.cpp b/src/backends/OutputHandler.cpp
index c1be5b7dc4..4dfa1a621e 100644
--- a/src/backends/OutputHandler.cpp
+++ b/src/backends/OutputHandler.cpp
@@ -25,6 +25,11 @@ void OutputHandler::CreateTensorHandles(const IWorkloadFactory& factory)
m_TensorHandle = factory.CreateTensorHandle(m_TensorInfo);
}
+void OutputHandler::CreateTensorHandles(const IWorkloadFactory& factory, DataLayout dataLayout)
+{
+ m_TensorHandle = factory.CreateTensorHandle(m_TensorInfo, dataLayout);
+}
+
void OutputHandler::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const
{
dataCollector.Push(m_TensorHandle.get(), m_TensorInfo);
diff --git a/src/backends/OutputHandler.hpp b/src/backends/OutputHandler.hpp
index dfc01844c9..97da87d8cc 100644
--- a/src/backends/OutputHandler.hpp
+++ b/src/backends/OutputHandler.hpp
@@ -39,6 +39,11 @@ public:
/// @param factory - Factory to be used for handler creation.
void CreateTensorHandles(const IWorkloadFactory& factory);
+ /// @brief - Creates tensor handlers used by the intermediate tensors. Does not allocate memory.
+ /// @param factory - Factory to be used for handler creation.
+ /// @param dataLayout - Data Layout to be used for handler creation.
+ void CreateTensorHandles(const IWorkloadFactory& factory, DataLayout dataLayout);
+
/// @brief - Gets the matching TensorInfo for the output.
/// @return - References to the output TensorInfo.
const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
diff --git a/src/backends/RefWorkloadFactory.cpp b/src/backends/RefWorkloadFactory.cpp
index 93b4d946c4..689adb628a 100644
--- a/src/backends/RefWorkloadFactory.cpp
+++ b/src/backends/RefWorkloadFactory.cpp
@@ -36,6 +36,12 @@ std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const Tens
return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
}
+std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const
+{
+ return std::make_unique<ScopedCpuTensorHandle>(tensorInfo);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/RefWorkloadFactory.hpp b/src/backends/RefWorkloadFactory.hpp
index 6b13377167..da0ca9b066 100644
--- a/src/backends/RefWorkloadFactory.hpp
+++ b/src/backends/RefWorkloadFactory.hpp
@@ -49,6 +49,9 @@ public:
virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const override;
+
virtual std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/WorkloadData.hpp b/src/backends/WorkloadData.hpp
index b5b0402237..5da9e8b1fd 100644
--- a/src/backends/WorkloadData.hpp
+++ b/src/backends/WorkloadData.hpp
@@ -142,11 +142,13 @@ struct Convolution2dQueueDescriptor : QueueDescriptorWithParameters<Convolution2
Convolution2dQueueDescriptor()
: m_Weight(nullptr)
, m_Bias(nullptr)
+ , m_DataLayout(DataLayout::NCHW)
{
}
const ConstCpuTensorHandle* m_Weight;
const ConstCpuTensorHandle* m_Bias;
+ DataLayout m_DataLayout;
void Validate(const WorkloadInfo& workloadInfo) const;
};
diff --git a/src/backends/WorkloadFactory.hpp b/src/backends/WorkloadFactory.hpp
index fbc6134574..77e810c9ad 100644
--- a/src/backends/WorkloadFactory.hpp
+++ b/src/backends/WorkloadFactory.hpp
@@ -49,6 +49,9 @@ public:
virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const = 0;
+ virtual std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
+ DataLayout dataLayout) const = 0;
+
virtual std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
const WorkloadInfo& info) const = 0;