diff options
Diffstat (limited to 'arm_compute/graph/frontend')
-rw-r--r-- | arm_compute/graph/frontend/IStreamOperators.h | 12 | ||||
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 2 | ||||
-rw-r--r-- | arm_compute/graph/frontend/Types.h | 2 |
3 files changed, 15 insertions, 1 deletions
diff --git a/arm_compute/graph/frontend/IStreamOperators.h b/arm_compute/graph/frontend/IStreamOperators.h index 350d78fd1c..4d680f9a0e 100644 --- a/arm_compute/graph/frontend/IStreamOperators.h +++ b/arm_compute/graph/frontend/IStreamOperators.h @@ -96,6 +96,18 @@ inline IStream &operator<<(IStream &s, DepthwiseConvolutionMethod depthwise_conv s.hints().depthwise_convolution_method_hint = depthwise_convolution_method_hint; return s; } +/** Overloaded stream operator to provide a fast math hint to the graph + * + * @param[in, out] s Stream to provide the hint to + * @param[in] fast_math_hint Convolution method hint to be considered + * + * @return Updated stream + */ +inline IStream &operator<<(IStream &s, FastMathHint fast_math_hint) +{ + s.hints().fast_math_hint = fast_math_hint; + return s; +} } // namespace frontend } // namespace graph } // namespace arm_compute diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 54cf515aa7..d122a7a967 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -197,7 +197,7 @@ public: NodeParams common_params = { name(), s.hints().target_hint }; return GraphBuilder::add_convolution_node(s.graph(), common_params, input, Size2D(_conv_width, _conv_height), _ofm, _conv_info, _num_groups, - s.hints().convolution_method_hint, + s.hints().convolution_method_hint, s.hints().fast_math_hint, std::move(_weights), std::move(_bias), std::move(_weights_quant_info), std::move(_out_quant_info)); } diff --git a/arm_compute/graph/frontend/Types.h b/arm_compute/graph/frontend/Types.h index 6cf7460900..47893613c7 100644 --- a/arm_compute/graph/frontend/Types.h +++ b/arm_compute/graph/frontend/Types.h @@ -45,6 +45,7 @@ using graph::PoolingLayerInfo; using graph::PoolingType; using graph::Target; using graph::ConvolutionMethod; +using graph::FastMathHint; using graph::DepthwiseConvolutionMethod; using graph::TensorDescriptor; using graph::DimensionRoundingType; @@ -63,6 +64,7 @@ struct StreamHints Target target_hint = { Target::UNSPECIFIED }; /**< Target execution hint */ ConvolutionMethod convolution_method_hint = { ConvolutionMethod::DEFAULT }; /**< Convolution method hint */ DepthwiseConvolutionMethod depthwise_convolution_method_hint = { DepthwiseConvolutionMethod::DEFAULT }; /**< Depthwise Convolution method hint */ + FastMathHint fast_math_hint = { FastMathHint::DISABLED }; /**< Fast math hint */ }; } // namespace frontend } // namespace graph |