diff options
Diffstat (limited to 'src/armnn/QuantizerVisitor.cpp')
-rw-r--r-- | src/armnn/QuantizerVisitor.cpp | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/src/armnn/QuantizerVisitor.cpp b/src/armnn/QuantizerVisitor.cpp index b5085be0a2..97a8bc1ad2 100644 --- a/src/armnn/QuantizerVisitor.cpp +++ b/src/armnn/QuantizerVisitor.cpp @@ -79,7 +79,37 @@ void QuantizerVisitor::VisitActivationLayer(const IConnectableLayer* layer, SetQuantizedInputConnections(layer, newLayer); } -void QuantizerVisitor::VisitInputLayer(const IConnectableLayer* layer, LayerBindingId id, const char* name) +void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer, + const FullyConnectedDescriptor& desc, + const ConstTensor& weights, + const char *name) +{ + std::vector<uint8_t> weightsBacking; + ConstTensor qWeights = CreateQuantizedConst(weights, weightsBacking); + + IConnectableLayer* newLayer = m_QuantizedNetwork->AddFullyConnectedLayer(desc, qWeights, name); + RecordLayer(layer, newLayer); + SetQuantizedInputConnections(layer, newLayer); +} + +void QuantizerVisitor::VisitFullyConnectedLayer(const IConnectableLayer *layer, + const FullyConnectedDescriptor& desc, + const ConstTensor& weights, + const ConstTensor& bias, + const char *name) +{ + std::vector<uint8_t> weightsBacking; + ConstTensor qWeights = CreateQuantizedConst(weights, weightsBacking); + + std::vector<uint8_t> biasBacking; + ConstTensor qBias = CreateQuantizedConst(bias, biasBacking); + + IConnectableLayer* newLayer = m_QuantizedNetwork->AddFullyConnectedLayer(desc, qWeights, qBias, name); + RecordLayer(layer, newLayer); + SetQuantizedInputConnections(layer, newLayer); +} + +void QuantizerVisitor::VisitInputLayer(const IConnectableLayer *layer, LayerBindingId id, const char *name) { IConnectableLayer* newLayer = m_QuantizedNetwork->AddInputLayer(id, name); RecordLayer(layer, newLayer); |