aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/QuantizerVisitor.cpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-06-21 15:25:19 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-06-21 16:08:15 +0000
commit389aa70c8a24fa2faf33df5f8cd9a99b0fabe971 (patch)
tree23794e376c15e5a7cd79bec95a23028944c8f9dc /src/armnn/QuantizerVisitor.cpp
parent5e1b0cf8a7519afb49874a83429ef9939a249f0d (diff)
downloadarmnn-experimental/transpose_conv2d.tar.gz
IVGCVSW-3322 Add Quantizer support for TransposeConvolution2DLayerexperimental/transpose_conv2d
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: I26997d7770585055b2b3256baad2800a4c5ed7e8
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--src/armnn/QuantizerVisitor.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index ef6b068145..292924c8e4 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -446,4 +446,32 @@ void QuantizerVisitor::VisitSubtractionLayer(const IConnectableLayer* layer,
SetQuantizedInputConnections(layer, newLayer);
}
+void QuantizerVisitor::VisitTransposeConvolution2dLayer(const IConnectableLayer* layer,
+ const TransposeConvolution2dDescriptor& descriptor,
+ const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
+ const char* name)
+{
+ // quantize weights
+ std::vector<uint8_t> weightsBacking;
+ ConstTensor qWeights = CreateQuantizedConst(weights, weightsBacking);
+
+ // quantize biases
+ std::vector<int32_t> biasesBacking;
+ Optional<ConstTensor> optionalQBiases;
+ if (biases.has_value())
+ {
+ ConstTensor qBiases = CreateQuantizedBias(layer, qWeights, biases, biasesBacking);
+ optionalQBiases = Optional<ConstTensor>(qBiases);
+ }
+
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddTransposeConvolution2dLayer(descriptor,
+ qWeights,
+ optionalQBiases,
+ name);
+
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
} //namespace armnn