aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadUtils.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadUtils.cpp94
1 files changed, 94 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadUtils.cpp b/src/backends/backendsCommon/WorkloadUtils.cpp
index c8105aea04..bd7f09b28a 100644
--- a/src/backends/backendsCommon/WorkloadUtils.cpp
+++ b/src/backends/backendsCommon/WorkloadUtils.cpp
@@ -7,6 +7,9 @@
#include <armnn/Utils.hpp>
#include <armnn/utility/NumericCast.hpp>
+#include <armnnUtils/DataLayoutIndexed.hpp>
+
+#include <fmt/format.h>
namespace armnn
{
@@ -107,6 +110,7 @@ ConstTensor ReorderWeightChannelsForAcl(const ConstTensor& weightHandle, DataLay
return ConstTensor(weightHandle.GetInfo(), permuteBuffer);
}
+
TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, DataLayout dataLayout)
{
// Convert the weight format from ArmNN's [ M, I, H, W ] (does NOT depend on the data layout) to either
@@ -130,6 +134,96 @@ TensorInfo ConvertWeightTensorInfoFromArmnnToAcl(const TensorInfo& weightInfo, D
return weightPermutedInfo;
}
+
+std::tuple<ConstTensor, unsigned int> Convert1HWOTensorToAcl(const ConstTensorHandle* weightTensor,
+ const TensorInfo& inputInfo,
+ const DataLayout dataLayout,
+ void* permuteBuffer)
+{
+ TensorInfo weightsInfo = weightTensor->GetTensorInfo();
+ unsigned int depthMultiplier = 1;
+ PermutationVector permutationVector{};
+ if (dataLayout == armnn::DataLayout::NHWC)
+ {
+ // No permutation required. Data layouts are the same.
+
+ depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[3];
+ }
+ else if (dataLayout == armnn::DataLayout::NCHW)
+ {
+ // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
+ depthMultiplier = weightsInfo.GetShape()[3] / inputInfo.GetShape()[1];
+ permutationVector = { 0, 2, 3, 1 };
+ }
+ else
+ {
+ throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}",
+ GetDataLayoutName(dataLayout)));
+ }
+
+ ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+
+ return std::make_tuple(weightsPermuted, depthMultiplier);
+}
+
+std::tuple<TensorInfo, unsigned int> Convert1HWOTensorInfoToAcl(const TensorInfo& weightInfo,
+ const TensorInfo& inputInfo,
+ const DataLayout dataLayout)
+{
+ unsigned int aclDepthMultiplier = 1;
+ TensorInfo weightsPermuted;
+ if (dataLayout == armnn::DataLayout::NHWC)
+ {
+ // No permutation required. Data layouts are the same.
+ aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[3];
+ weightsPermuted = weightInfo;
+ }
+ else if (dataLayout == armnn::DataLayout::NCHW)
+ {
+ // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
+ aclDepthMultiplier = weightInfo.GetShape()[3] / inputInfo.GetShape()[1];
+ PermutationVector permutationVector{ 0, 2, 3, 1 };
+ weightsPermuted = armnnUtils::Permuted(weightInfo, permutationVector);
+ }
+ else
+ {
+ throw InvalidArgumentException(fmt::format("Unknown data layout for tensor info conversion: {}",
+ GetDataLayoutName(dataLayout)));
+ }
+
+ return std::make_tuple(weightsPermuted, aclDepthMultiplier);
+}
+
+
+std::tuple<ConstTensor, unsigned int> Convert1HWOtoMIHW(const ConstTensorHandle* weightTensor,
+ const TensorInfo& inputInfo,
+ const DataLayout& dataLayout,
+ void* permuteBuffer)
+{
+ TensorInfo weightsInfo = weightTensor->GetTensorInfo();
+
+ if (weightsInfo.HasPerAxisQuantization())
+ {
+ throw InvalidArgumentException("Can't convert tensor from [1,H,W,Cout] to [M,Cin,H,W] when per channel "
+ "quantization is applied.");
+ }
+
+ // Reshape weights [ 1, H, W, I*M ] --> [ H, W, I, M ]
+ auto weightsShape = weightsInfo.GetShape();
+ auto channelIndex = armnnUtils::DataLayoutIndexed(dataLayout).GetChannelsIndex();
+ unsigned int depthMultiplier = weightsShape[3] / inputInfo.GetShape()[channelIndex];
+ weightsInfo.SetShape({ weightsShape[1],
+ weightsShape[2],
+ inputInfo.GetShape()[channelIndex],
+ depthMultiplier});
+
+ // Permute [ H, W, I, M ] --> [ M, I, H, W ]
+ PermutationVector permutationVector = { 2, 3, 1, 0 };
+ ConstTensor weightsPermuted = PermuteTensor(weightTensor, permutationVector, permuteBuffer);
+
+ return std::make_tuple(weightsPermuted, depthMultiplier);
+}
+
armnn::ConstTensor ConvertWeightTensorFromArmnnToAcl(const ConstTensorHandle* weightTensor,
DataLayout dataLayout,
void* permuteBuffer)