aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadUtils.cpp88
1 files changed, 82 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index fa387a7a0b..3185ba00d3 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -9,8 +9,7 @@ namespace armnn
{
armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor,
- const PermutationVector& permutationVector,
- void* permuteBuffer)
+ const PermutationVector& permutationVector, void* permuteBuffer)
{
BOOST_ASSERT_MSG(tensor, "Invalid input tensor");
BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
@@ -44,16 +43,70 @@ void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout)
weightShape[0],
weightShape[1],
weightShape[2] * weightShape[3] });
- break;
- case DataLayout::NCHW:
- default:
- // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
weightInfo.SetShape({ 1,
weightShape[0] * weightShape[1],
weightShape[2],
weightShape[3] });
break;
+ case DataLayout::NCHW:
+ default:
+ // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ]
+ weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] });
+ break;
+ }
+}
+
+template <typename DataType>
+ConstTensor ReorderWeightChannelsForAcl(const ConstTensor& weightHandle, DataLayout dataLayout, void* permuteBuffer)
+{
+ DataType* weight = static_cast<DataType*>(permuteBuffer);
+ const TensorShape& weightShape = weightHandle.GetShape();
+ unsigned int multiplier;
+ unsigned int height;
+ unsigned int width;
+ unsigned int inputChannels;
+ switch (dataLayout)
+ {
+ case DataLayout::NHWC: //It actually is [ H, W, I, M ]
+ height = weightShape[0];
+ width = weightShape[1];
+ inputChannels = weightShape[2];
+ multiplier = weightShape[3];
+ break;
+ case DataLayout::NCHW: //It actually is [ M, I, H, W ]
+ default:
+ height = weightShape[2];
+ width = weightShape[3];
+ inputChannels = weightShape[1];
+ multiplier = weightShape[0];
+ break;
}
+
+ DataType weightAclOrder[height*width*inputChannels*multiplier];
+ unsigned int destinationWeightsChannel;
+ unsigned int totalChannels = inputChannels * multiplier;
+ unsigned int channelSize = height * width;
+
+ for (unsigned int originWeightsChannel = 0; originWeightsChannel < totalChannels; originWeightsChannel++)
+ {
+ if (originWeightsChannel % inputChannels == 0)
+ {
+ destinationWeightsChannel = originWeightsChannel / inputChannels;
+ }
+ else
+ {
+ destinationWeightsChannel = (originWeightsChannel - 1) / inputChannels + multiplier;
+ }
+
+ for (unsigned int i = 0; i < channelSize; i++)
+ {
+ weightAclOrder[i + destinationWeightsChannel * channelSize] =
+ weight[i + originWeightsChannel * channelSize];
+ }
+ }
+
+ ::memcpy(permuteBuffer, weightAclOrder, weightHandle.GetInfo().GetNumBytes());
+ return ConstTensor(weightHandle.GetInfo(), permuteBuffer);
}
TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
@@ -86,6 +139,9 @@ armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle*
BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor");
BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer");
+ auto multiplier = weightTensor->GetTensorInfo().GetShape()[0];
+ auto inputChannels = weightTensor->GetTensorInfo().GetShape()[1];
+
// Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
// [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library
@@ -101,6 +157,26 @@ armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle*
}
ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+ // Shuffle the weights data to obtain the channel order needed used by Acl
+ if (multiplier > 1 and inputChannels > 1 and dataLayout == DataLayout::NCHW)
+ {
+ switch (weightPermuted.GetDataType())
+ {
+ case DataType::Float32:
+ weightPermuted = ReorderWeightChannelsForAcl<float>(weightPermuted, dataLayout, permuteBuffer);
+ break;
+ case DataType::Float16:
+ weightPermuted =
+ ReorderWeightChannelsForAcl<half_float::half>(weightPermuted, dataLayout, permuteBuffer);
+ break;
+ case DataType::QuantisedAsymm8:
+ weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer);
+ break;
+ default:
+ break;
+ }
+ }
+
// 2. Reshape the weights
ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout);