aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/frontend/Layers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r--arm_compute/graph/frontend/Layers.h48
1 files changed, 45 insertions, 3 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 1a71c89e54..d062d5a53e 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -578,6 +578,34 @@ public:
: _num_outputs(num_outputs),
_weights(std::move(weights)),
_bias(std::move(bias)),
+ _weights_ss(nullptr),
+ _bias_ss(nullptr),
+ _fc_info(fc_info),
+ _weights_quant_info(std::move(weights_quant_info)),
+ _out_quant_info(std::move(out_quant_info))
+ {
+ }
+
+ /** Construct a fully connected layer.
+ *
+ * @param[in] num_outputs Number of outputs.
+ * @param[in] sub_stream_weights Graph sub-stream for the weights.
+ * @param[in] sub_stream_bias Graph sub-stream for the bias.
+ * @param[in] fc_info (Optional) Fully connected layer metadata
+ * @param[in] weights_quant_info (Optional) Weights quantization information
+ * @param[in] out_quant_info (Optional) Output quantization info
+ */
+ FullyConnectedLayer(unsigned int num_outputs,
+ SubStream &&sub_stream_weights,
+ SubStream &&sub_stream_bias,
+ const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
+ const QuantizationInfo weights_quant_info = QuantizationInfo(),
+ const QuantizationInfo out_quant_info = QuantizationInfo())
+ : _num_outputs(num_outputs),
+ _weights(nullptr),
+ _bias(nullptr),
+ _weights_ss(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream_weights))),
+ _bias_ss(arm_compute::support::cpp14::make_unique<SubStream>(std::move(sub_stream_bias))),
_fc_info(fc_info),
_weights_quant_info(std::move(weights_quant_info)),
_out_quant_info(std::move(out_quant_info))
@@ -594,15 +622,29 @@ public:
{
NodeParams common_params = { name(), s.hints().target_hint };
NodeIdxPair input = { s.tail_node(), 0 };
- return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
- std::move(_weights), std::move(_bias), _fc_info,
- std::move(_weights_quant_info), std::move(_out_quant_info));
+ if(_weights != nullptr)
+ {
+ return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
+ std::move(_weights), std::move(_bias), _fc_info,
+ std::move(_weights_quant_info), std::move(_out_quant_info));
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(_weights_ss == nullptr);
+
+ NodeID bias_nid = (_bias_ss == nullptr) ? EmptyNodeID : _bias_ss->tail_node();
+ return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
+ _weights_ss->tail_node(), bias_nid, _fc_info,
+ std::move(_out_quant_info));
+ }
}
private:
unsigned int _num_outputs;
ITensorAccessorUPtr _weights;
ITensorAccessorUPtr _bias;
+ std::unique_ptr<SubStream> _weights_ss;
+ std::unique_ptr<SubStream> _bias_ss;
const FullyConnectedLayerInfo _fc_info;
const QuantizationInfo _weights_quant_info;
const QuantizationInfo _out_quant_info;