diff options
Diffstat (limited to 'src/armnn/layers/FullyConnectedLayer.cpp')
-rw-r--r-- | src/armnn/layers/FullyConnectedLayer.cpp | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp index 1597e8c2c3..8b8f010bdb 100644 --- a/src/armnn/layers/FullyConnectedLayer.cpp +++ b/src/armnn/layers/FullyConnectedLayer.cpp @@ -22,11 +22,15 @@ FullyConnectedLayer::FullyConnectedLayer(const FullyConnectedDescriptor& param, std::unique_ptr<IWorkload> FullyConnectedLayer::CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const { + // on this level constant data should not be released.. + BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null."); + FullyConnectedQueueDescriptor descriptor; descriptor.m_Weight = m_Weight.get(); if (m_Param.m_BiasEnabled) { + BOOST_ASSERT_MSG(m_Bias != nullptr, "FullyConnectedLayer: Bias data should not be null."); descriptor.m_Bias = m_Bias.get(); } return factory.CreateFullyConnected(descriptor, PrepInfoAndDesc(descriptor, graph)); @@ -45,25 +49,41 @@ FullyConnectedLayer* FullyConnectedLayer::Clone(Graph& graph) const return std::move(layer); } +std::vector<TensorShape> FullyConnectedLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + BOOST_ASSERT(inputShapes.size() == 2); + const TensorShape& inputShape = inputShapes[0]; + const TensorShape weightShape = inputShapes[1]; + + // Output for FC is [1, w[1]]. + unsigned int batches = inputShape[0]; + unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1; + + return std::vector<TensorShape>({ TensorShape({batches, weightShape[dimIdx]})}); +} + void FullyConnectedLayer::ValidateTensorShapesFromInputs() { - ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection() != nullptr, - "FullyConnectedLayer: InputSlot must be connected to an OutputSlot"); - ConditionalThrow<LayerValidationException>(GetInputSlot(0).GetConnection()->IsTensorInfoSet(), - "FullyConnectedLayer: TensorInfo must be set on connected OutputSlot."); + VerifyLayerConnections(1, CHECK_LOCATION()); + // check if we m_Weight data is not nullptr + BOOST_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null."); - TensorShape const& weightShape = m_Weight->GetTensorInfo().GetShape(); + auto inferredShapes = InferOutputShapes({ + GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), + m_Weight->GetTensorInfo().GetShape() }); - // output for FC is [1, w[1]] - unsigned int batches = GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape()[0]; - unsigned int dimIdx = m_Param.m_TransposeWeightMatrix ? 0 : 1; - TensorShape outShape({batches, weightShape[dimIdx]}); + BOOST_ASSERT(inferredShapes.size() == 1); ConditionalThrowIfNotEqual<LayerValidationException>( "FullyConnectedLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", GetOutputSlot(0).GetTensorInfo().GetShape(), - outShape); + inferredShapes[0]); +} + +Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef() +{ + return {m_Weight, m_Bias}; } } // namespace armnn |