diff options
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/core/Types.h | 1 | ||||
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h | 6 |
2 files changed, 4 insertions, 3 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 0a25277b57..f4955ed457 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -805,6 +805,7 @@ struct FullyConnectedLayerInfo bool transpose_weights{ true }; /**< Transpose weights if true. */ bool are_weights_reshaped{ false }; /**< Reshape the weights tensor if false. */ bool retain_internal_weights{ false }; /**< Retain internal reshaped weights. */ + bool fp_mixed_precision{ false }; /**< Use wider accumulators (32 bit instead of 16 for FP16) to improve accuracy. */ /** Sets the weights trained data layout * diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index 7f872532e4..f284359663 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -174,9 +174,9 @@ public: void prepare() override; private: - void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, bool retain_internal_weights); - void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, bool retain_internal_weights); - void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, bool retain_internal_weights); + void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); + void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); + void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); MemoryGroup _memory_group; IWeightsManager *_weights_manager; |