diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-04-27 19:07:19 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | cac13b1cfd593889271f8e2191be2039b8d88f36 (patch) | |
tree | d1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /src/graph/nodes/FullyConnectedLayer.cpp | |
parent | ad0c7388f6261989a268ffb2d042f2bd80736e3f (diff) | |
download | ComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz |
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/graph/nodes/FullyConnectedLayer.cpp')
-rw-r--r-- | src/graph/nodes/FullyConnectedLayer.cpp | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/src/graph/nodes/FullyConnectedLayer.cpp b/src/graph/nodes/FullyConnectedLayer.cpp index cbf2b35ddd..d94a7851ff 100644 --- a/src/graph/nodes/FullyConnectedLayer.cpp +++ b/src/graph/nodes/FullyConnectedLayer.cpp @@ -38,10 +38,11 @@ FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs) _outputs.resize(1, NullTensorID); } -TensorShape FullyConnectedLayerNode::compute_weights_shape(TensorShape input_shape, unsigned int num_outputs) +TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_outputs) { unsigned int num_weights = 1; - unsigned int num_dimensions = input_shape.num_dimensions(); + unsigned int num_dimensions = input_descriptor.shape.num_dimensions(); // Ignore the batch dimension if there is one: if(num_dimensions == 2 || num_dimensions == 4) { @@ -49,20 +50,29 @@ TensorShape FullyConnectedLayerNode::compute_weights_shape(TensorShape input_sha } for(unsigned int i = 0; i < num_dimensions; i++) { - num_weights *= input_shape[i]; + num_weights *= input_descriptor.shape[i]; } - return TensorShape(num_weights, num_outputs); + + TensorDescriptor weights_descriptor = input_descriptor; + weights_descriptor.shape = TensorShape(num_weights, num_outputs); + + return weights_descriptor; } -TensorShape FullyConnectedLayerNode::compute_output_shape(TensorShape input_shape, unsigned int num_outputs) +TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, + unsigned int num_outputs) { // Note: Only 1D batch space is supported at the moment - unsigned int batches = input_shape[1]; - if(input_shape.num_dimensions() > 2) + unsigned int batches = input_descriptor.shape[1]; + if(input_descriptor.shape.num_dimensions() > 2) { - batches = input_shape[3]; + batches = input_descriptor.shape[3]; } - return TensorShape(num_outputs, batches); + + TensorDescriptor output_descriptor = input_descriptor; + output_descriptor.shape = TensorShape(num_outputs, batches); + + return output_descriptor; } bool FullyConnectedLayerNode::forward_descriptors() @@ -83,15 +93,7 @@ TensorDescriptor FullyConnectedLayerNode::configure_output(size_t idx) const const Tensor *src = input(0); ARM_COMPUTE_ERROR_ON(src == nullptr); - TensorDescriptor output_info = src->desc(); - TensorShape output_shape = compute_output_shape(src->desc().shape, _num_outputs); - output_info.shape = output_shape; - return output_info; -} - -Status FullyConnectedLayerNode::validate() -{ - return Status{}; + return compute_output_descriptor(src->desc(), _num_outputs); } NodeType FullyConnectedLayerNode::type() const |