aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/QuantizerVisitor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r--src/armnn/QuantizerVisitor.cpp28
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp
index ef6b068145..292924c8e4 100644
--- a/src/armnn/QuantizerVisitor.cpp
+++ b/src/armnn/QuantizerVisitor.cpp
@@ -446,4 +446,32 @@ void QuantizerVisitor::VisitSubtractionLayer(const IConnectableLayer* layer,
SetQuantizedInputConnections(layer, newLayer);
}
+void QuantizerVisitor::VisitTransposeConvolution2dLayer(const IConnectableLayer* layer,
+ const TransposeConvolution2dDescriptor& descriptor,
+ const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
+ const char* name)
+{
+ // quantize weights
+ std::vector<uint8_t> weightsBacking;
+ ConstTensor qWeights = CreateQuantizedConst(weights, weightsBacking);
+
+ // quantize biases
+ std::vector<int32_t> biasesBacking;
+ Optional<ConstTensor> optionalQBiases;
+ if (biases.has_value())
+ {
+ ConstTensor qBiases = CreateQuantizedBias(layer, qWeights, biases, biasesBacking);
+ optionalQBiases = Optional<ConstTensor>(qBiases);
+ }
+
+ IConnectableLayer* newLayer = m_QuantizedNetwork->AddTransposeConvolution2dLayer(descriptor,
+ qWeights,
+ optionalQBiases,
+ name);
+
+ RecordLayer(layer, newLayer);
+ SetQuantizedInputConnections(layer, newLayer);
+}
+
} //namespace armnn