diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index c34797725f..633c272f00 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -1841,7 +1841,6 @@ void IDeserializer::DeserializerImpl::ParseFullyConnected(GraphPtr graph, unsign CHECK_LAYERS(graph, 0, layerIndex); auto inputs = GetInputs(graph, layerIndex); CHECK_LOCATION(); - CHECK_VALID_SIZE(inputs.size(), 1); auto outputs = GetOutputs(graph, layerIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -1853,20 +1852,36 @@ void IDeserializer::DeserializerImpl::ParseFullyConnected(GraphPtr graph, unsign armnn::FullyConnectedDescriptor fullyConnectedDescriptor; fullyConnectedDescriptor.m_BiasEnabled = flatBufferDescriptor->biasEnabled(); fullyConnectedDescriptor.m_TransposeWeightMatrix = flatBufferDescriptor->transposeWeightsMatrix(); + fullyConnectedDescriptor.m_ConstantWeights = flatBufferDescriptor->constantWeights(); + uint32_t numInputs = 1; + if (!fullyConnectedDescriptor.m_ConstantWeights) + { + numInputs = 2; + if (fullyConnectedDescriptor.m_BiasEnabled) + { + numInputs = 3; + } + } + CHECK_VALID_SIZE(inputs.size(), numInputs); - armnn::ConstTensor weightsTensor = ToConstTensor(flatBufferLayer->weights()); - - armnn::IConnectableLayer* layer; + armnn::Optional <armnn::ConstTensor> optionalWeights = armnn::EmptyOptional(); armnn::Optional<armnn::ConstTensor> optionalBiases = armnn::EmptyOptional(); - if (flatBufferDescriptor->biasEnabled()) + if (fullyConnectedDescriptor.m_ConstantWeights) { - armnn::ConstTensor biasTensorData = ToConstTensor(flatBufferLayer->biases()); - optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensorData); + armnn::ConstTensor weightsTensorData = ToConstTensor(flatBufferLayer->weights()); + optionalWeights = armnn::Optional<armnn::ConstTensor>(weightsTensorData); + + if (flatBufferDescriptor->biasEnabled()) + { + armnn::ConstTensor biasTensorData = ToConstTensor(flatBufferLayer->biases()); + optionalBiases = armnn::Optional<armnn::ConstTensor>(biasTensorData); + } } - layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, - weightsTensor, - optionalBiases, - layerName.c_str()); + + armnn::IConnectableLayer* layer = m_Network->AddFullyConnectedLayer(fullyConnectedDescriptor, + optionalWeights, + optionalBiases, + layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); |