about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> 2019-11-07 14:49:26 +0000
committer: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> 2019-11-07 14:53:00 +0000
commit: 60a346b348a71b62ca26dcbb33eb881203ee0a68 (patch)
tree: bc7ab674c26febeab456911a9bffabf416df3693
parent: 4bc42cb7b4e35652ec2a64c054eec7d2aca997c3 (diff)
download: android-nn-driver-60a346b348a71b62ca26dcbb33eb881203ee0a68.tar.gz
IVGCVSW-4104 Support per-axis quantization for GROUPED_CONV2D
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ice7c4d3273db31130ec64edc1b76d1c9d5197961
-rw-r--r-- 1.2/HalPolicy.cpp | 56
1 files changed, 33 insertions, 23 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 12c08047..f901a31b 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -702,12 +702,7 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
}
ConstTensor weights = weightsPin.GetConstTensor();
- if (weights.GetInfo().HasPerAxisQuantization())
- {
- return Fail("%s: Per-axis quantization is not supported", __func__);
- }
-
- ConstTensor biases = biasesPin.GetConstTensor();
+ ConstTensor biases = biasesPin.GetConstTensor();
SanitizeBiasQuantizationScale(biases.GetInfo(), weights.GetInfo(), inputInfo);
const TensorShape& inputShape = inputInfo.GetShape();
@@ -838,6 +833,8 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
//
// Set up Convolution2d layers for each group
//
+
+ // Set up group tensor shapes
TensorShape groupInputShape(inputShape);
groupInputShape[channelsIndex] = channelsPerGroup;
@@ -849,27 +846,25 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
TensorShape groupBiasesShape({ 1 });
- const TensorInfo groupInputInfo (groupInputShape,
- inputInfo.GetDataType(),
- inputInfo.GetQuantizationScale(),
- inputInfo.GetQuantizationOffset());
- const TensorInfo groupWeightsInfo(groupWeightsShape,
- weights.GetInfo().GetDataType(),
- weights.GetInfo().GetQuantizationScale(),
- weights.GetInfo().GetQuantizationOffset());
- const TensorInfo groupBiasesInfo (groupBiasesShape,
- biases.GetInfo().GetDataType(),
- biases.GetInfo().GetQuantizationScale(),
- biases.GetInfo().GetQuantizationOffset());
- const TensorInfo groupOutputInfo (groupOutputShape,
- outputInfo.GetDataType(),
- outputInfo.GetQuantizationScale(),
- outputInfo.GetQuantizationOffset());
+ // Set up group tensor infos
+ TensorInfo groupInputInfo(inputInfo);
+ groupInputInfo.SetShape(groupInputShape);
+
+ const TensorInfo& weightsInfo = weights.GetInfo();
+ TensorInfo groupWeightsInfo(weightsInfo);
+ groupWeightsInfo.SetShape(groupWeightsShape);
+
+ const TensorInfo& biasesInfo = biases.GetInfo();
+ TensorInfo groupBiasesInfo(biasesInfo);
+ groupBiasesInfo.SetShape(groupBiasesShape);
+
+ TensorInfo groupOutputInfo(outputInfo);
+ groupOutputInfo.SetShape(groupOutputShape);
const unsigned int weightsDataTypeSize = GetDataTypeSize(groupWeightsInfo.GetDataType());
const unsigned int biasesDataTypeSize = GetDataTypeSize(groupBiasesInfo.GetDataType());
- std::vector<IConnectableLayer*> convLayers(numGroups*channelMultiplier, nullptr);
+ std::vector<IConnectableLayer*> convLayers(numGroups * channelMultiplier, nullptr);
for (unsigned int group = 0u; group < numGroups; ++group)
{
for (unsigned int m = 0u; m < channelMultiplier; ++m)
@@ -879,6 +874,21 @@ bool HalPolicy::ConvertGroupedConv2d(const Operation& operation, const Model& mo
const unsigned int weightsDataOffset = groupWeightsShape.GetNumElements() * index * weightsDataTypeSize;
const unsigned int biasesDataOffset = groupBiasesShape.GetNumElements() * index * biasesDataTypeSize;
+ if (weightsInfo.HasPerAxisQuantization())
+ {
+ // Extract per-axis quantization scales for group weights
+ const std::vector<float>& weightsQuantScales = weightsInfo.GetQuantizationScales();
+ groupWeightsInfo.SetQuantizationScales(
+ std::vector<float>(weightsQuantScales.begin() + index,
+ weightsQuantScales.begin() + index + groupWeightsShape[0]));
+
+ // Extract per-axis quantization scales for group biases
+ const std::vector<float>& biasesQuantScales = biasesInfo.GetQuantizationScales();
+ groupBiasesInfo.SetQuantizationScales(
+ std::vector<float>(biasesQuantScales.begin() + index,
+ biasesQuantScales.begin() + index + groupWeightsShape[0]));
+ }
+
// Extract weights and biases data for current group convolution
ConstTensor groupWeights(groupWeightsInfo,
static_cast<const void *>(reinterpret_cast<const char *>(weights.GetMemoryArea()) +