diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadUtils.cpp | 88 |
1 files changed, 82 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp index fa387a7a0b..3185ba00d3 100644 --- a/src/backends/backendsCommon/WorkloadUtils.cpp +++ b/src/backends/backendsCommon/WorkloadUtils.cpp @@ -9,8 +9,7 @@ namespace armnn { armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor, - const PermutationVector& permutationVector, - void* permuteBuffer) + const PermutationVector& permutationVector, void* permuteBuffer) { BOOST_ASSERT_MSG(tensor, "Invalid input tensor"); BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); @@ -44,16 +43,70 @@ void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout) weightShape[0], weightShape[1], weightShape[2] * weightShape[3] }); - break; - case DataLayout::NCHW: - default: - // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ] weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] }); break; + case DataLayout::NCHW: + default: + // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ] + weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] }); + break; + } +} + +template <typename DataType> +ConstTensor ReorderWeightChannelsForAcl(const ConstTensor& weightHandle, DataLayout dataLayout, void* permuteBuffer) +{ + DataType* weight = static_cast<DataType*>(permuteBuffer); + const TensorShape& weightShape = weightHandle.GetShape(); + unsigned int multiplier; + unsigned int height; + unsigned int width; + unsigned int inputChannels; + switch (dataLayout) + { + case DataLayout::NHWC: //It actually is [ H, W, I, M ] + height = weightShape[0]; + width = weightShape[1]; + inputChannels = weightShape[2]; + multiplier = weightShape[3]; + break; + case DataLayout::NCHW: //It actually is [ M, I, H, W ] + default: + height = weightShape[2]; + width = weightShape[3]; + inputChannels = weightShape[1]; + multiplier = weightShape[0]; + break; } + + DataType weightAclOrder[height*width*inputChannels*multiplier]; + unsigned int destinationWeightsChannel; + unsigned int totalChannels = inputChannels * multiplier; + unsigned int channelSize = height * width; + + for (unsigned int originWeightsChannel = 0; originWeightsChannel < totalChannels; originWeightsChannel++) + { + if (originWeightsChannel % inputChannels == 0) + { + destinationWeightsChannel = originWeightsChannel / inputChannels; + } + else + { + destinationWeightsChannel = (originWeightsChannel - 1) / inputChannels + multiplier; + } + + for (unsigned int i = 0; i < channelSize; i++) + { + weightAclOrder[i + destinationWeightsChannel * channelSize] = + weight[i + originWeightsChannel * channelSize]; + } + } + + ::memcpy(permuteBuffer, weightAclOrder, weightHandle.GetInfo().GetNumBytes()); + return ConstTensor(weightHandle.GetInfo(), permuteBuffer); } TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout) @@ -86,6 +139,9 @@ armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor"); BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); + auto multiplier = weightTensor->GetTensorInfo().GetShape()[0]; + auto inputChannels = weightTensor->GetTensorInfo().GetShape()[1]; + // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library @@ -101,6 +157,26 @@ armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* } ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer); + // Shuffle the weights data to obtain the channel order needed used by Acl + if (multiplier > 1 and inputChannels > 1 and dataLayout == DataLayout::NCHW) + { + switch (weightPermuted.GetDataType()) + { + case DataType::Float32: + weightPermuted = ReorderWeightChannelsForAcl<float>(weightPermuted, dataLayout, permuteBuffer); + break; + case DataType::Float16: + weightPermuted = + ReorderWeightChannelsForAcl<half_float::half>(weightPermuted, dataLayout, permuteBuffer); + break; + case DataType::QuantisedAsymm8: + weightPermuted = ReorderWeightChannelsForAcl<uint8_t>(weightPermuted, dataLayout, permuteBuffer); + break; + default: + break; + } + } + // 2. Reshape the weights ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout); |