diff options
Diffstat (limited to 'arm_compute/graph/nodes/FullyConnectedLayerNode.h')
-rw-r--r-- | arm_compute/graph/nodes/FullyConnectedLayerNode.h | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/arm_compute/graph/nodes/FullyConnectedLayerNode.h b/arm_compute/graph/nodes/FullyConnectedLayerNode.h index a7712f46b9..3bcf386d64 100644 --- a/arm_compute/graph/nodes/FullyConnectedLayerNode.h +++ b/arm_compute/graph/nodes/FullyConnectedLayerNode.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,10 +39,22 @@ public: * @param[in] num_outputs Number of neurons in the layer * @param[in] out_quant_info (Optional) Output quantization info * @param[in] fc_info (Optional) Additional information about the fully connected layer + * @param[in] fast_math_hint (Optional) Fast math hint */ FullyConnectedLayerNode(unsigned int num_outputs, QuantizationInfo out_quant_info = QuantizationInfo(), - FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), + FastMathHint fast_math_hint = FastMathHint::Disabled); + /** Sets the fast math fast hint + * + * @param[in] hint Hint to use for fullyconnected layer + */ + void set_fast_math_hint(FastMathHint hint); + /** Fast math hint accessor + * + * @return Fast math hint to be used by the node + */ + FastMathHint fast_math_hint() const; /** Sets fused activation * * @param[in] fused_activation Fused activation to set @@ -61,7 +73,7 @@ public: */ static TensorDescriptor compute_weights_descriptor(const TensorDescriptor &input_descriptor, unsigned int num_outputs, - FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const QuantizationInfo &weights_quant_info = QuantizationInfo()); /** Computes fully connected layer output descriptor * @@ -86,7 +98,7 @@ public: NodeType type() const override; bool forward_descriptors() override; TensorDescriptor configure_output(size_t idx) const override; - void accept(INodeVisitor &v) override; + void accept(INodeVisitor &v) override; static constexpr NodeType node_type = NodeType::FullyConnectedLayer; @@ -94,6 +106,7 @@ private: unsigned int _num_outputs; QuantizationInfo _out_quant_info; FullyConnectedLayerInfo _info; + FastMathHint _fast_math_hint; }; } // namespace graph } // namespace arm_compute |