// // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #include "WorkloadUtils.hpp" namespace armnn { armnn::ConstTensor PermuteTensor(const ConstCpuTensorHandle* tensor, const PermutationVector& permutationVector, void* permuteBuffer) { BOOST_ASSERT_MSG(tensor, "Invalid input tensor"); BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); TensorInfo tensorInfo = tensor->GetTensorInfo(); if (permutationVector.GetSize() > 0) { tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector); armnnUtils::Permute(tensorInfo.GetShape(), permutationVector, tensor->GetConstTensor(), permuteBuffer, GetDataTypeSize(tensorInfo.GetDataType())); } else { ::memcpy(permuteBuffer, tensor->GetConstTensor(), tensorInfo.GetNumBytes()); } return ConstTensor(tensorInfo, permuteBuffer); } void ReshapeWeightsForAcl(TensorInfo& weightInfo, DataLayout dataLayout) { // Reshape the weights in-place const TensorShape& weightShape = weightInfo.GetShape(); switch (dataLayout) { case DataLayout::NHWC: // The data layout is NHWC, reshape from [ H, W, I, M ] to [ 1, H, W, I * M ] weightInfo.SetShape({ 1, weightShape[0], weightShape[1], weightShape[2] * weightShape[3] }); break; case DataLayout::NCHW: default: // The data layout is NCHW, reshape from [ M, I, H, W ] to [ 1, I * M, H, W, ] weightInfo.SetShape({ 1, weightShape[0] * weightShape[1], weightShape[2], weightShape[3] }); break; } } TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout) { // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library // 1. Permute the weights if necessary // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done // starting from the current shape of [ M, I, H, W ] TensorInfo weightPermutedInfo(weightInfo); if (dataLayout == DataLayout::NHWC) { // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ] PermutationVector permutationVector{ 3, 2, 0, 1 }; weightPermutedInfo = armnnUtils::Permuted(weightInfo, permutationVector); } // 2. Reshape the weights ReshapeWeightsForAcl(weightPermutedInfo, dataLayout); // 3. Return the permuted weight info return weightPermutedInfo; } armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstCpuTensorHandle* weightTensor, DataLayout dataLayout, void* permuteBuffer) { BOOST_ASSERT_MSG(weightTensor, "Invalid input tensor"); BOOST_ASSERT_MSG(permuteBuffer, "Invalid permute buffer"); // Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either // [ 1, H, W, I * M ] (if NHWC) or [ 1, I * M, H, W ] (if NCHW), as required by the compute library // 1. Permute the weights if necessary // If the data layout is NCHW no permutation is necessary, as a reshape to [ 1, I * M, H, W ] can be better done // starting from the current shape of [ M, I, H, W ] // If no permutation is necessary, leave the permutation vector empty PermutationVector permutationVector{}; if (dataLayout == DataLayout::NHWC) { // The data layout is NHWC, then permute the weights from [ M, I, H, W ] to [ H, W, I, M ] permutationVector = { 3, 2, 0, 1 }; } ConstTensor weightPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer); // 2. Reshape the weights ReshapeWeightsForAcl(weightPermuted.GetInfo(), dataLayout); // 3. Return both the tensor and the allocated storage to ensure that the data stays alive return weightPermuted; } } // namespace armnn