// // Copyright © 2019 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include #include #include template void PermuteTensorNchwToNhwc(armnn::TensorInfo& tensorInfo, std::vector& tensorData) { const armnn::PermutationVector nchwToNhwc = { 0, 3, 1, 2 }; tensorInfo = armnnUtils::Permuted(tensorInfo, nchwToNhwc); std::vector tmp(tensorData.size()); armnnUtils::Permute(tensorInfo.GetShape(), nchwToNhwc, tensorData.data(), tmp.data(), sizeof(T)); tensorData = tmp; } template void PermuteTensorNhwcToNchw(armnn::TensorInfo& tensorInfo, std::vector& tensorData) { const armnn::PermutationVector nhwcToNchw = { 0, 2, 3, 1 }; tensorInfo = armnnUtils::Permuted(tensorInfo, nhwcToNchw); std::vector tmp(tensorData.size()); armnnUtils::Permute(tensorInfo.GetShape(), nhwcToNchw, tensorData.data(), tmp.data(), sizeof(T)); tensorData = tmp; }