From cac13b1cfd593889271f8e2191be2039b8d88f36 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 27 Apr 2018 19:07:19 +0100 Subject: COMPMID-1097: Port mobilenet to NHWC Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/graph/nodes/FullyConnectedLayer.cpp | 38 +++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) (limited to 'src/graph/nodes/FullyConnectedLayer.cpp') diff --git a/src/graph/nodes/FullyConnectedLayer.cpp b/src/graph/nodes/FullyConnectedLayer.cpp index cbf2b35ddd..d94a7851ff 100644 --- a/src/graph/nodes/FullyConnectedLayer.cpp +++ b/src/graph/nodes/FullyConnectedLayer.cpp @@ -38,10 +38,11 @@ FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs) _outputs.resize(1, NullTensorID); } -TensorShape FullyConnectedLayerNode::compute_weights_shape(TensorShape input_shape, unsigned int num_outputs) +TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_outputs) { unsigned int num_weights = 1; - unsigned int num_dimensions = input_shape.num_dimensions(); + unsigned int num_dimensions = input_descriptor.shape.num_dimensions(); // Ignore the batch dimension if there is one: if(num_dimensions == 2 || num_dimensions == 4) { @@ -49,20 +50,29 @@ TensorShape FullyConnectedLayerNode::compute_weights_shape(TensorShape input_sha } for(unsigned int i = 0; i < num_dimensions; i++) { - num_weights *= input_shape[i]; + num_weights *= input_descriptor.shape[i]; } - return TensorShape(num_weights, num_outputs); + + TensorDescriptor weights_descriptor = input_descriptor; + weights_descriptor.shape = TensorShape(num_weights, num_outputs); + + return weights_descriptor; } -TensorShape FullyConnectedLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_outputs) +TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_outputs) { // Note: Only 1D batch space is supported at the moment - unsigned int batches = input_shape[1]; - if(input_shape.num_dimensions() > 2) + unsigned int batches = input_descriptor.shape[1]; + if(input_descriptor.shape.num_dimensions() > 2) { - batches = input_shape[3]; + batches = input_descriptor.shape[3]; } - return TensorShape(num_outputs, batches); + + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape = TensorShape(num_outputs, batches); + + return output_descriptor; } bool FullyConnectedLayerNode::forward_descriptors() @@ -83,15 +93,7 @@ TensorDescriptor FullyConnectedLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); - TensorDescriptor output_info = src->desc(); - TensorShape output_shape = compute_output_shape(src->desc().shape, _num_outputs); - output_info.shape = output_shape; - return output_info; -} - -Status FullyConnectedLayerNode::validate() -{ - return Status{}; + return compute_output_descriptor(src->desc(), _num_outputs); } NodeType FullyConnectedLayerNode::type() const -- cgit v1.2.1