about | summary | refs | log | tree | commit | diff
path: root/arm_compute/graph
diff options
context:
space:
mode:
author Michele Di Giorgio <michele.digiorgio@arm.com> 2019-03-08 14:52:17 +0000
committer Michele Di Giorgio <michele.digiorgio@arm.com> 2019-03-13 11:58:43 +0000
commit a42f55f4184cb63c73b74ed76759bdcbb18656e8 (patch)
tree 4b82cfde94994bd5ab350fd2362896dee8391c68 /arm_compute/graph
parent acce504ec4aebe5e5da470c1cfc3cee401ff11f3 (diff)
download ComputeLibrary-a42f55f4184cb63c73b74ed76759bdcbb18656e8.tar.gz
COMPMID-1995: Allow weights and bias to be passed as SubStream in FullyConnectedLayer
Change-Id: Iae2e7d55fd66d5932c29f78ef3112289d9b69b84
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/848
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Marquez <pablo.tello@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'arm_compute/graph')
-rw-r--r-- arm_compute/graph/GraphBuilder.h 17
-rw-r--r-- arm_compute/graph/frontend/Layers.h 48
2 files changed, 62 insertions, 3 deletions
diff --git a/arm_compute/graph/GraphBuilder.h b/arm_compute/graph/GraphBuilder.h
index 1296f56482..a2a938b1cc 100644
--- a/arm_compute/graph/GraphBuilder.h
+++ b/arm_compute/graph/GraphBuilder.h
@@ -236,6 +236,23 @@ public:
static NodeID add_flatten_node(Graph &g, NodeParams params, NodeIdxPair input);
/** Adds a fully connected layer node to the graph
*
+ * @param[in] g Graph to add the layer to
+ * @param[in] params Common node parameters
+ * @param[in] input Input to the fully connected layer node as a NodeID-Index pair
+ * @param[in] num_outputs Number of output neurons
+ * @param[in] weights_nid Node ID of the weights node data
+ * @param[in] bias_nid (Optional) Node ID of the bias node data. Defaults to EmptyNodeID
+ * @param[in] fc_info (Optional) Fully connected layer metadata
+ * @param[in] out_quant_info (Optional) Output quantization info
+ *
+ * @return Node ID of the created node, EmptyNodeID in case of error
+ */
+ static NodeID add_fully_connected_layer(Graph &g, NodeParams params, NodeIdxPair input, unsigned int num_outputs,
+ NodeID weights_nid, NodeID bias_nid = EmptyNodeID,
+ const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const QuantizationInfo out_quant_info = QuantizationInfo());
+ /** Adds a fully connected layer node to the graph
+ *
* @param[in] g Graph to add the layer to
* @param[in] params Common node parameters
* @param[in] input Input to the fully connected layer node as a NodeID-Index pair
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 1a71c89e54..d062d5a53e 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -578,6 +578,34 @@ public:
: _num_outputs(num_outputs),
_weights(std::move(weights)),
_bias(std::move(bias)),
+ _weights_ss(nullptr),
+ _bias_ss(nullptr),
+ _fc_info(fc_info),
+ _weights_quant_info(std::move(weights_quant_info)),
+ _out_quant_info(std::move(out_quant_info))
+ {
+ }
+
+ /** Construct a fully connected layer.
+ *
+ * @param[in] num_outputs Number of outputs.
+ * @param[in] sub_stream_weights Graph sub-stream for the weights.
+ * @param[in] sub_stream_bias Graph sub-stream for the bias.
+ * @param[in] fc_info (Optional) Fully connected layer metadata
+ * @param[in] weights_quant_info (Optional) Weights quantization information
+ * @param[in] out_quant_info (Optional) Output quantization info
+ */
+ FullyConnectedLayer(unsigned int num_outputs,
+ SubStream &&sub_stream_weights,
+ SubStream &&sub_stream_bias,
+ const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const QuantizationInfo weights_quant_info = QuantizationInfo(),
+ const QuantizationInfo out_quant_info = QuantizationInfo())
+ : _num_outputs(num_outputs),
+ _weights(nullptr),
+ _bias(nullptr),
+ _weights_ss(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream_weights))),
+ _bias_ss(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream_bias))),
_fc_info(fc_info),
_weights_quant_info(std::move(weights_quant_info)),
_out_quant_info(std::move(out_quant_info))
@@ -594,15 +622,29 @@ public:
{
NodeParams common_params = { name(), s.hints().target_hint };
NodeIdxPair input = { s.tail_node(), 0 };
- return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
- std::move(_weights), std::move(_bias), _fc_info,
- std::move(_weights_quant_info), std::move(_out_quant_info));
+ if(_weights != nullptr)
+ {
+ return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
+ std::move(_weights), std::move(_bias), _fc_info,
+ std::move(_weights_quant_info), std::move(_out_quant_info));
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(_weights_ss == nullptr);
+
+ NodeID bias_nid = (_bias_ss == nullptr) ? EmptyNodeID : _bias_ss->tail_node();
+ return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
+ _weights_ss->tail_node(), bias_nid, _fc_info,
+ std::move(_out_quant_info));
+ }
}
private:
unsigned int _num_outputs;
ITensorAccessorUPtr _weights;
ITensorAccessorUPtr _bias;
+ std::unique_ptr<SubStream> _weights_ss;
+ std::unique_ptr<SubStream> _bias_ss;
const FullyConnectedLayerInfo _fc_info;
const QuantizationInfo _weights_quant_info;
const QuantizationInfo _out_quant_info;