diff options
Diffstat (limited to 'src/graph/nodes/FullyConnectedLayer.cpp')
-rw-r--r-- | src/graph/nodes/FullyConnectedLayer.cpp | 40 |
1 files changed, 27 insertions, 13 deletions
diff --git a/src/graph/nodes/FullyConnectedLayer.cpp b/src/graph/nodes/FullyConnectedLayer.cpp index 34c432a1ce..1eed69ddaf 100644 --- a/src/graph/nodes/FullyConnectedLayer.cpp +++ b/src/graph/nodes/FullyConnectedLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,22 +21,36 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "arm_compute/graph/nodes/FullyConnectedLayerNode.h" - #include "arm_compute/core/Utils.h" #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/INodeVisitor.h" +#include "arm_compute/graph/nodes/FullyConnectedLayerNode.h" namespace arm_compute { namespace graph { -FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs, QuantizationInfo out_quant_info, FullyConnectedLayerInfo fc_info) - : _num_outputs(num_outputs), _out_quant_info(std::move(out_quant_info)), _info(fc_info) +FullyConnectedLayerNode::FullyConnectedLayerNode(unsigned int num_outputs, + QuantizationInfo out_quant_info, + FullyConnectedLayerInfo fc_info, + FastMathHint fast_math_hint) + : _num_outputs(num_outputs), + _out_quant_info(std::move(out_quant_info)), + _info(fc_info), + _fast_math_hint(fast_math_hint) { _input_edges.resize(3, EmptyEdgeID); _outputs.resize(1, NullTensorID); } +void FullyConnectedLayerNode::set_fast_math_hint(FastMathHint hint) +{ + _fast_math_hint = hint; +} + +FastMathHint FullyConnectedLayerNode::fast_math_hint() const +{ + return _fast_math_hint; +} void FullyConnectedLayerNode::set_fused_activation(ActivationLayerInfo fused_activation) { @@ -51,11 +65,11 @@ TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const Tenso unsigned int num_weights = 1; unsigned int num_dimensions = input_descriptor.shape.num_dimensions(); // Ignore the batch dimension if there is one: - if(num_dimensions == 2 || num_dimensions == 4) + if (num_dimensions == 2 || num_dimensions == 4) { num_dimensions--; } - for(unsigned int i = 0; i < num_dimensions; i++) + for (unsigned int i = 0; i < num_dimensions; i++) { num_weights *= input_descriptor.shape[i]; } @@ -64,13 +78,13 @@ TensorDescriptor FullyConnectedLayerNode::compute_weights_descriptor(const Tenso weights_descriptor.shape = TensorShape(num_weights, num_outputs); // If weights are tranposed, use tranposed shape - if(!fc_info.transpose_weights) + if (!fc_info.transpose_weights) { weights_descriptor.shape = TensorShape(num_outputs, num_weights); } // Set quantization info if present - if(!weights_quant_info.empty()) + if (!weights_quant_info.empty()) { weights_descriptor.quant_info = weights_quant_info; } @@ -84,7 +98,7 @@ TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const Tensor { // Note: Only 1D batch space is supported at the moment unsigned int batches = input_descriptor.shape[1]; - if(input_descriptor.shape.num_dimensions() > 2) + if (input_descriptor.shape.num_dimensions() > 2) { batches = input_descriptor.shape[3]; } @@ -94,7 +108,7 @@ TensorDescriptor FullyConnectedLayerNode::compute_output_descriptor(const Tensor output_descriptor.shape = TensorShape(num_outputs, batches); // Set quantization info if present - if(!out_quant_info.empty()) + if (!out_quant_info.empty()) { output_descriptor.quant_info = out_quant_info; } @@ -109,7 +123,7 @@ FullyConnectedLayerInfo FullyConnectedLayerNode::info() const bool FullyConnectedLayerNode::forward_descriptors() { - if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID)) + if ((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID)) { Tensor *dst = output(0); ARM_COMPUTE_ERROR_ON(dst == nullptr); @@ -138,4 +152,4 @@ void FullyConnectedLayerNode::accept(INodeVisitor &v) v.visit(*this); } } // namespace graph -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute