| author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-07 14:49:26 +0000 |
|---|---|---|
| committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-07 14:53:00 +0000 |
| commit | 60a346b348a71b62ca26dcbb33eb881203ee0a68 (patch) | |
| tree | bc7ab674c26febeab456911a9bffabf416df3693 | |
| parent | 4bc42cb7b4e35652ec2a64c054eec7d2aca997c3 (diff) | |
| download | android-nn-driver-60a346b348a71b62ca26dcbb33eb881203ee0a68.tar.gz | |
IVGCVSW-4104 Support per-axis quantization for GROUPED_CONV2D
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ice7c4d3273db31130ec64edc1b76d1c9d5197961
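For background, per-axis (per-output-channel) quantization gives each output channel of the weight tensor its own scale rather than sharing a single scale across the whole tensor. The sketch below is an illustrative, self-contained example of the arithmetic this implies (symmetric int8 weights with a per-channel zero point of 0); the function and variable names are hypothetical and this is not code from the driver itself.

```cpp
#include <cstdint>
#include <cstddef>
#include <iostream>
#include <vector>

// Dequantize an int8 weight tensor that carries one scale per output channel
// (per-axis quantization, zero point 0), stored channel-major as
// [outputChannels][elementsPerChannel]. Illustrative sketch only.
std::vector<float> DequantizePerAxis(const std::vector<int8_t>& quantized,
                                     const std::vector<float>& channelScales,
                                     std::size_t elementsPerChannel)
{
    std::vector<float> result(quantized.size());
    for (std::size_t i = 0; i < quantized.size(); ++i)
    {
        const std::size_t channel = i / elementsPerChannel;
        result[i] = channelScales[channel] * static_cast<float>(quantized[i]);
    }
    return result;
}

int main()
{
    // Two output channels, three weights each, with a different scale per channel
    const std::vector<int8_t> weights = { 10, -20, 30, 10, -20, 30 };
    const std::vector<float>  scales  = { 0.5f, 0.1f };

    for (float w : DequantizePerAxis(weights, scales, 3))
    {
        std::cout << w << " ";
    }
    std::cout << "\n"; // prints: 5 -10 15 1 -2 3
    return 0;
}
```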
-rw-r--r-- | 1.2/HalPolicy.cpp | 56 |
1 file changed, 33 insertions, 23 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 12c08047..f901a31b 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -702,12 +702,7 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
     }
 
     ConstTensor weights = weightsPin.GetConstTensor();
-    if (weights.GetInfo().HasPerAxisQuantization())
-    {
-        return Fail("%s: Per-axis quantization is not supported", __func__);
-    }
-
-    ConstTensor biases = biasesPin.GetConstTensor();
+    ConstTensor biases  = biasesPin.GetConstTensor();
     SanitizeBiasQuantizationScale(biases.GetInfo(), weights.GetInfo(), inputInfo);
 
     const TensorShape& inputShape = inputInfo.GetShape();
@@ -838,6 +833,8 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
     //
     // Set up Convolution2d layers for each group
     //
+
+    // Set up group tensor shapes
     TensorShape groupInputShape(inputShape);
     groupInputShape[channelsIndex] = channelsPerGroup;
 
@@ -849,27 +846,25 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
 
     TensorShape groupBiasesShape({ 1 });
 
-    const TensorInfo groupInputInfo (groupInputShape,
-                                     inputInfo.GetDataType(),
-                                     inputInfo.GetQuantizationScale(),
-                                     inputInfo.GetQuantizationOffset());
-    const TensorInfo groupWeightsInfo(groupWeightsShape,
-                                      weights.GetInfo().GetDataType(),
-                                      weights.GetInfo().GetQuantizationScale(),
-                                      weights.GetInfo().GetQuantizationOffset());
-    const TensorInfo groupBiasesInfo (groupBiasesShape,
-                                      biases.GetInfo().GetDataType(),
-                                      biases.GetInfo().GetQuantizationScale(),
-                                      biases.GetInfo().GetQuantizationOffset());
-    const TensorInfo groupOutputInfo (groupOutputShape,
-                                      outputInfo.GetDataType(),
-                                      outputInfo.GetQuantizationScale(),
-                                      outputInfo.GetQuantizationOffset());
+    // Set up group tensor infos
+    TensorInfo groupInputInfo(inputInfo);
+    groupInputInfo.SetShape(groupInputShape);
+
+    const TensorInfo& weightsInfo = weights.GetInfo();
+    TensorInfo groupWeightsInfo(weightsInfo);
+    groupWeightsInfo.SetShape(groupWeightsShape);
+
+    const TensorInfo& biasesInfo = biases.GetInfo();
+    TensorInfo groupBiasesInfo(biasesInfo);
+    groupBiasesInfo.SetShape(groupBiasesShape);
+
+    TensorInfo groupOutputInfo(outputInfo);
+    groupOutputInfo.SetShape(groupOutputShape);
 
     const unsigned int weightsDataTypeSize = GetDataTypeSize(groupWeightsInfo.GetDataType());
     const unsigned int biasesDataTypeSize = GetDataTypeSize(groupBiasesInfo.GetDataType());
 
-    std::vector<IConnectableLayer*> convLayers(numGroups*channelMultiplier, nullptr);
+    std::vector<IConnectableLayer*> convLayers(numGroups * channelMultiplier, nullptr);
     for (unsigned int group = 0u; group < numGroups; ++group)
     {
         for (unsigned int m = 0u; m < channelMultiplier; ++m)
@@ -879,6 +874,21 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
 
             const unsigned int weightsDataOffset = groupWeightsShape.GetNumElements() * index * weightsDataTypeSize;
             const unsigned int biasesDataOffset = groupBiasesShape.GetNumElements() * index * biasesDataTypeSize;
+            if (weightsInfo.HasPerAxisQuantization())
+            {
+                // Extract per-axis quantization scales for group weights
+                const std::vector<float>& weightsQuantScales = weightsInfo.GetQuantizationScales();
+                groupWeightsInfo.SetQuantizationScales(
+                    std::vector<float>(weightsQuantScales.begin() + index,
+                                       weightsQuantScales.begin() + index + groupWeightsShape[0]));
+
+                // Extract per-axis quantization scales for group biases
+                const std::vector<float>& biasesQuantScales = biasesInfo.GetQuantizationScales();
+                groupBiasesInfo.SetQuantizationScales(
+                    std::vector<float>(biasesQuantScales.begin() + index,
+                                       biasesQuantScales.begin() + index + groupWeightsShape[0]));
+            }
+
             // Extract weights and biases data for current group convolution
             ConstTensor groupWeights(groupWeightsInfo,
                                      static_cast<const void *>(reinterpret_cast<const char *>(weights.GetMemoryArea()) +
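The substance of the patch is that per-axis quantized weights are no longer rejected; instead, each group's weight and bias TensorInfo receives the slice of per-channel scales that belongs to that group. The standalone sketch below illustrates that slicing idea with hypothetical names and a simplified layout (one contiguous run of outputChannelsPerGroup scales per group); it is not the driver's exact indexing, which uses the per-group convolution index shown in the diff above.

```cpp
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

// Given one quantization scale per output channel of a grouped convolution's
// weight tensor, return the contiguous slice of scales for a single group.
// Names and layout are hypothetical; this is an illustration, not driver code.
std::vector<float> ExtractGroupScales(const std::vector<float>& perChannelScales,
                                      std::size_t outputChannelsPerGroup,
                                      std::size_t groupIndex)
{
    const std::size_t start = groupIndex * outputChannelsPerGroup;
    assert(start + outputChannelsPerGroup <= perChannelScales.size());
    return std::vector<float>(perChannelScales.begin() + start,
                              perChannelScales.begin() + start + outputChannelsPerGroup);
}

int main()
{
    // Example: 4 groups with 2 output channels each -> 8 per-axis scales in total
    const std::vector<float> weightScales = { 0.10f, 0.11f, 0.20f, 0.21f,
                                              0.30f, 0.31f, 0.40f, 0.41f };

    for (std::size_t group = 0; group < 4; ++group)
    {
        const std::vector<float> groupScales = ExtractGroupScales(weightScales, 2, group);
        std::cout << "group " << group << ": "
                  << groupScales[0] << " " << groupScales[1] << "\n";
    }
    return 0;
}
```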