From 195b0ba457d0020e1f54fb0c0378040e1c75d510 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Thu, 2 Aug 2018 17:18:51 +0100 Subject: MLCE-36: FC tranpose weights Change-Id: I3b8a6c00e61ba6da459ca5fc7275393f9d073aed Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/142533 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- arm_compute/core/Types.h | 11 +++++++++++ arm_compute/graph/frontend/Types.h | 1 + arm_compute/graph/nodes/FullyConnectedLayerNode.h | 2 ++ 3 files changed, 14 insertions(+) (limited to 'arm_compute') diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 81d652dd7d..d9109e4565 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -701,6 +701,17 @@ struct FullyConnectedLayerInfo weights_trained_layout = layout; return *this; } + /** Sets the transpose weights flag + * + * @param[in] should_transpose_weights Boolean flag indicating if weights should be transposed + * + * @return Updated object + */ + FullyConnectedLayerInfo &set_transpose_weights(bool should_transpose_weights) + { + transpose_weights = should_transpose_weights; + return *this; + } }; /** Pooling Layer Information class */ diff --git a/arm_compute/graph/frontend/Types.h b/arm_compute/graph/frontend/Types.h index 8f6312f318..ebbf7101ac 100644 --- a/arm_compute/graph/frontend/Types.h +++ b/arm_compute/graph/frontend/Types.h @@ -39,6 +39,7 @@ using graph::TensorShape; using graph::PermutationVector; using graph::ActivationLayerInfo; +using graph::FullyConnectedLayerInfo; using graph::NormalizationLayerInfo; using graph::NormType; using graph::PadStrideInfo; diff --git a/arm_compute/graph/nodes/FullyConnectedLayerNode.h b/arm_compute/graph/nodes/FullyConnectedLayerNode.h index 1bff6006c8..33f9b1eefe 100644 --- a/arm_compute/graph/nodes/FullyConnectedLayerNode.h +++ b/arm_compute/graph/nodes/FullyConnectedLayerNode.h @@ -49,12 +49,14 @@ public: * * @param[in] input_descriptor Input descriptor * @param[in] num_outputs Number of output neurons + * @param[in] fc_info (Optional) Additional information about the fully connected layer * @param[in] weights_quant_info (Optional) Weights quantization info * * @return Weights descriptor */ static TensorDescriptor compute_weights_descriptor(const TensorDescriptor &input_descriptor, unsigned int num_outputs, + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), QuantizationInfo weights_quant_info = QuantizationInfo()); /** Computes fully connected layer output descriptor * -- cgit v1.2.1