diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadUtils.cpp | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp new file mode 100644 index 0000000000..fa387a7a0b --- /dev/null +++ b/src/backends/backendsCommon/WorkloadUtils.cpp @@ -0,0 +1,111 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "WorkloadUtils.hpp" + +namespace armnn +{ + +armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor, + const PermutationVector& permutationVector, + void* permuteBuffer) +{ + BOOST_ASSERT_MSG(tensor, "Invalid input tensor"); + BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); + + TensorInfo tensorInfo = tensor->GetTensorInfo(); + + if (permutationVector.GetSize() > 0) + { + tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector); + armnnUtils::Permute(tensorInfo.GetShape(), permutationVector, + tensor->GetConstTensor<void>(), permuteBuffer, + GetDataTypeSize(tensorInfo.GetDataType())); + } + else + { + ::memcpy(permuteBuffer, tensor->GetConstTensor<void>(), tensorInfo.GetNumBytes()); + } + + return ConstTensor(tensorInfo, permuteBuffer); +} + +void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout) +{ + // Reshape the weights in-place + const TensorShape& weightShape = weightInfo.GetShape(); + switch (dataLayout) + { + case DataLayout::NHWC: + // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ] + weightInfo.SetShape({ 1, + weightShape[0], + weightShape[1], + weightShape[2] * weightShape[3] }); + break; + case DataLayout::NCHW: + default: + // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ] + weightInfo.SetShape({ 1, + weightShape[0] * weightShape[1], + weightShape[2], + weightShape[3] }); + break; + } +} + +TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout) +{ + // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either + // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library + + // 1. Permute the weights if necessary + // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done + // starting from the current shape of [ M, I, H, W ] + TensorInfo weightPermutedInfo(weightInfo); + if (dataLayout == DataLayout::NHWC) + { + // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ] + PermutationVector permutationVector{ 3, 2, 0, 1 }; + weightPermutedInfo = armnnUtils::Permuted(weightInfo, permutationVector); + } + + // 2. Reshape the weights + ReshapeWeightsForAcl(weightPermutedInfo, dataLayout); + + // 3. Return the permuted weight info + return weightPermutedInfo; +} + +armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor, + DataLayout dataLayout, + void* permuteBuffer) +{ + BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor"); + BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); + + // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either + // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library + + // 1. Permute the weights if necessary + // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done + // starting from the current shape of [ M, I, H, W ] + // If no permutation is necessary, leave the permutation vector empty + PermutationVector permutationVector{}; + if (dataLayout == DataLayout::NHWC) + { + // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ] + permutationVector = { 3, 2, 0, 1 }; + } + ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer); + + // 2. Reshape the weights + ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout); + + // 3. Return both the tensor and the allocated storage to ensure that the data stays alive + return weightPermuted; +} + +} // namespace armnn |