diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index ef6b068145..292924c8e4 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -446,4 +446,32 @@ void QuantizerVisitor::VisitSubtractionLayer(const IConnectableLayer* layer, SetQuantizedInputConnections(layer, newLayer); } +void QuantizerVisitor::VisitTransposeConvolution2dLayer(const IConnectableLayer* layer, + const TransposeConvolution2dDescriptor& descriptor, + const ConstTensor& weights, + const Optional<ConstTensor>& biases, + const char* name) +{ + // quantize weights + std::vector<uint8_t> weightsBacking; + ConstTensor qWeights = CreateQuantizedConst(weights, weightsBacking); + + // quantize biases + std::vector<int32_t> biasesBacking; + Optional<ConstTensor> optionalQBiases; + if (biases.has_value()) + { + ConstTensor qBiases = CreateQuantizedBias(layer, qWeights, biases, biasesBacking); + optionalQBiases = Optional<ConstTensor>(qBiases); + } + + IConnectableLayer* newLayer = m_QuantizedNetwork->AddTransposeConvolution2dLayer(descriptor, + qWeights, + optionalQBiases, + name); + + RecordLayer(layer, newLayer); + SetQuantizedInputConnections(layer, newLayer); +} + } //namespace armnn |