diff options
-rw-r--r-- | arm_compute/core/Types.h | 12 | ||||
-rw-r--r-- | arm_compute/graph/GraphBuilder.h | 6 | ||||
-rw-r--r-- | arm_compute/graph/Types.h | 1 | ||||
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 26 | ||||
-rw-r--r-- | src/graph/GraphBuilder.cpp | 6 |
5 files changed, 33 insertions, 18 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 343952f0b2..00370918bd 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -689,6 +689,18 @@ struct FullyConnectedLayerInfo bool transpose_weights{ true }; /**< Transpose weights if true. */ bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */ bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */ + + /** Sets the weights trained data layout + * + * @param[in] layout Data layout that the weights were trained with + * + * @return Updated object + */ + FullyConnectedLayerInfo &set_weights_trained_layout(DataLayout layout) + { + weights_trained_layout = layout; + return *this; + } }; /** Pooling Layer Information class */ diff --git a/arm_compute/graph/GraphBuilder.h b/arm_compute/graph/GraphBuilder.h index 5bb1df4a11..a2f7618876 100644 --- a/arm_compute/graph/GraphBuilder.h +++ b/arm_compute/graph/GraphBuilder.h @@ -218,6 +218,7 @@ public: * @param[in] num_outputs Number of output neurons * @param[in] weights_accessor (Optional) Accessor of the weights node data * @param[in] bias_accessor (Optional) Accessor of the bias node data + * @param[in] fc_info (Optional) Fully connected layer metadata * @param[in] weights_quant_info (Optional) Weights quantization info * @param[in] out_quant_info (Optional) Output quantization info * @@ -225,8 +226,9 @@ public: */ static NodeID add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs, ITensorAccessorUPtr weights_accessor = nullptr, ITensorAccessorUPtr bias_accessor = nullptr, - const QuantizationInfo weights_quant_info = QuantizationInfo(), - const QuantizationInfo out_quant_info = QuantizationInfo()); + const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), + const QuantizationInfo weights_quant_info = QuantizationInfo(), + const QuantizationInfo out_quant_info = QuantizationInfo()); /** Adds a normalization layer node to the graph * * @param[in] g Graph to add the node to diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h index f22f50ac82..ee0bf429ad 100644 --- a/arm_compute/graph/Types.h +++ b/arm_compute/graph/Types.h @@ -48,6 +48,7 @@ using arm_compute::Size2D; using arm_compute::ActivationLayerInfo; using arm_compute::NormType; using arm_compute::NormalizationLayerInfo; +using arm_compute::FullyConnectedLayerInfo; using arm_compute::PadStrideInfo; using arm_compute::PoolingLayerInfo; using arm_compute::PoolingType; diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index a222c8546e..0a1a0cf1e4 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -384,17 +384,20 @@ public: * @param[in] num_outputs Number of outputs. * @param[in] weights Accessor to get weights from. * @param[in] bias Accessor to get bias from. + * @param[in] fc_info (Optional) Fully connected layer metadata * @param[in] weights_quant_info (Optional) Weights quantization information * @param[in] out_quant_info (Optional) Output quantization info */ - FullyConnectedLayer(unsigned int num_outputs, - ITensorAccessorUPtr weights, - ITensorAccessorUPtr bias, - const QuantizationInfo weights_quant_info = QuantizationInfo(), - const QuantizationInfo out_quant_info = QuantizationInfo()) + FullyConnectedLayer(unsigned int num_outputs, + ITensorAccessorUPtr weights, + ITensorAccessorUPtr bias, + const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), + const QuantizationInfo weights_quant_info = QuantizationInfo(), + const QuantizationInfo out_quant_info = QuantizationInfo()) : _num_outputs(num_outputs), _weights(std::move(weights)), _bias(std::move(bias)), + _fc_info(fc_info), _weights_quant_info(std::move(weights_quant_info)), _out_quant_info(std::move(out_quant_info)) { @@ -405,16 +408,17 @@ public: NodeParams common_params = { name(), s.hints().target_hint }; NodeIdxPair input = { s.tail_node(), 0 }; return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs, - std::move(_weights), std::move(_bias), + std::move(_weights), std::move(_bias), _fc_info, std::move(_weights_quant_info), std::move(_out_quant_info)); } private: - unsigned int _num_outputs; - ITensorAccessorUPtr _weights; - ITensorAccessorUPtr _bias; - const QuantizationInfo _weights_quant_info; - const QuantizationInfo _out_quant_info; + unsigned int _num_outputs; + ITensorAccessorUPtr _weights; + ITensorAccessorUPtr _bias; + const FullyConnectedLayerInfo _fc_info; + const QuantizationInfo _weights_quant_info; + const QuantizationInfo _out_quant_info; }; /** Normalization Layer */ diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index 7f567fd559..007f69177a 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -424,6 +424,7 @@ NodeID GraphBuilder::add_flatten_node(Graph &g, NodeParams params, NodeIdxPair i NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs, ITensorAccessorUPtr weights_accessor, ITensorAccessorUPtr bias_accessor, + const FullyConnectedLayerInfo fc_info, const QuantizationInfo weights_quant_info, const QuantizationInfo out_quant_info) { CHECK_NODEIDX_PAIR(input, g); @@ -451,11 +452,6 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node b_nid = add_const_node_with_name(g, params, "Bias", b_desc, std::move(bias_accessor)); } - // Add fully connected info - // FIXME (COMPMID-1367) : Expose weights layout - FullyConnectedLayerInfo fc_info; - fc_info.weights_trained_layout = DataLayout::NCHW; - // Create fully connected node and connect NodeID fc_nid = g.add_node<FullyConnectedLayerNode>(num_outputs, out_quant_info, fc_info); g.add_connection(input.node_id, input.index, fc_nid, 0); |