From 60a346b348a71b62ca26dcbb33eb881203ee0a68 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Thu, 7 Nov 2019 14:49:26 +0000 Subject: IVGCVSW-4104 Support per-axis quantization for GROUPED_CONV2D Signed-off-by: Aron Virginas-Tar Change-Id: Ice7c4d3273db31130ec64edc1b76d1c9d5197961 --- 1.2/HalPolicy.cpp | 56 ++++++++++++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 23 deletions(-) diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 12c08047..f901a31b 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -702,12 +702,7 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo } ConstTensor weights = weightsPin.GetConstTensor(); - if (weights.GetInfo().HasPerAxisQuantization()) - { - return Fail("%s: Per-axis quantization is not supported", __func__); - } - - ConstTensor biases = biasesPin.GetConstTensor(); + ConstTensor biases = biasesPin.GetConstTensor(); SanitizeBiasQuantizationScale(biases.GetInfo(), weights.GetInfo(), inputInfo); const TensorShape& inputShape = inputInfo.GetShape(); @@ -838,6 +833,8 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo // // Set up Convolution2d layers for each group // + + // Set up group tensor shapes TensorShape groupInputShape(inputShape); groupInputShape[channelsIndex] = channelsPerGroup; @@ -849,27 +846,25 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo TensorShape groupBiasesShape({ 1 }); - const TensorInfo groupInputInfo (groupInputShape, - inputInfo.GetDataType(), - inputInfo.GetQuantizationScale(), - inputInfo.GetQuantizationOffset()); - const TensorInfo groupWeightsInfo(groupWeightsShape, - weights.GetInfo().GetDataType(), - weights.GetInfo().GetQuantizationScale(), - weights.GetInfo().GetQuantizationOffset()); - const TensorInfo groupBiasesInfo (groupBiasesShape, - biases.GetInfo().GetDataType(), - biases.GetInfo().GetQuantizationScale(), - biases.GetInfo().GetQuantizationOffset()); - const TensorInfo groupOutputInfo (groupOutputShape, - outputInfo.GetDataType(), - outputInfo.GetQuantizationScale(), - outputInfo.GetQuantizationOffset()); + // Set up group tensor infos + TensorInfo groupInputInfo(inputInfo); + groupInputInfo.SetShape(groupInputShape); + + const TensorInfo& weightsInfo = weights.GetInfo(); + TensorInfo groupWeightsInfo(weightsInfo); + groupWeightsInfo.SetShape(groupWeightsShape); + + const TensorInfo& biasesInfo = biases.GetInfo(); + TensorInfo groupBiasesInfo(biasesInfo); + groupBiasesInfo.SetShape(groupBiasesShape); + + TensorInfo groupOutputInfo(outputInfo); + groupOutputInfo.SetShape(groupOutputShape); const unsigned int weightsDataTypeSize = GetDataTypeSize(groupWeightsInfo.GetDataType()); const unsigned int biasesDataTypeSize = GetDataTypeSize(groupBiasesInfo.GetDataType()); - std::vector convLayers(numGroups*channelMultiplier, nullptr); + std::vector convLayers(numGroups * channelMultiplier, nullptr); for (unsigned int group = 0u; group < numGroups; ++group) { for (unsigned int m = 0u; m < channelMultiplier; ++m) @@ -879,6 +874,21 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo const unsigned int weightsDataOffset = groupWeightsShape.GetNumElements() * index * weightsDataTypeSize; const unsigned int biasesDataOffset = groupBiasesShape.GetNumElements() * index * biasesDataTypeSize; + if (weightsInfo.HasPerAxisQuantization()) + { + // Extract per-axis quantization scales for group weights + const std::vector& weightsQuantScales = weightsInfo.GetQuantizationScales(); + groupWeightsInfo.SetQuantizationScales( + std::vector(weightsQuantScales.begin() + index, + weightsQuantScales.begin() + index + groupWeightsShape[0])); + + // Extract per-axis quantization scales for group biases + const std::vector& biasesQuantScales = biasesInfo.GetQuantizationScales(); + groupBiasesInfo.SetQuantizationScales( + std::vector(biasesQuantScales.begin() + index, + biasesQuantScales.begin() + index + groupWeightsShape[0])); + } + // Extract weights and biases data for current group convolution ConstTensor groupWeights(groupWeightsInfo, static_cast(reinterpret_cast(weights.GetMemoryArea()) + -- cgit v1.2.1