aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-05-02 13:59:04 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:50 +0000
commit59631a174e1b5ef23bd3a0102f60b57c99502766 (patch)
tree5d8e15d7a3b65e5071db82e2937ee1808953823f /arm_compute/graph
parentef9e05978ab008e533cc76a8e6f10c9e86a880c1 (diff)
downloadComputeLibrary-59631a174e1b5ef23bd3a0102f60b57c99502766.tar.gz
COMPMID-1104 Add fast math hint in the graph API
Change-Id: I83db135fa94c6884e080f0229a9b6430d908c029 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129823 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/graph')
-rw-r--r--arm_compute/graph/GraphBuilder.h5
-rw-r--r--arm_compute/graph/TypePrinter.h18
-rw-r--r--arm_compute/graph/Types.h7
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h5
-rw-r--r--arm_compute/graph/frontend/IStreamOperators.h12
-rw-r--r--arm_compute/graph/frontend/Layers.h2
-rw-r--r--arm_compute/graph/frontend/Types.h2
-rw-r--r--arm_compute/graph/nodes/ConvolutionLayerNode.h17
8 files changed, 62 insertions, 6 deletions
diff --git a/arm_compute/graph/GraphBuilder.h b/arm_compute/graph/GraphBuilder.h
index bbbbefbcbb..aea28eb8d6 100644
--- a/arm_compute/graph/GraphBuilder.h
+++ b/arm_compute/graph/GraphBuilder.h
@@ -99,6 +99,8 @@ public:
ITensorAccessorUPtr beta_accessor = nullptr, ITensorAccessorUPtr gamma_accessor = nullptr);
/** Adds a convolution layer node to the graph
*
+ * TODO (COMPMID-1113): Add a graph descriptor for convolution layer node
+ *
* @param[in] g Graph to add the node to
* @param[in] params Common node parameters
* @param[in] input Input to the convolution layer node as a NodeID-Index pair
@@ -107,6 +109,7 @@ public:
* @param[in] conv_info Convolution layer information
* @param[in] num_groups (Optional) Number of groups for a grouped convolution. Defaults to 1
* @param[in] method (Optional) Convolution method to use
+ * @param[in] fast_math_hint (Optional) Fast math hint
* @param[in] weights_accessor (Optional) Accessor of the weights node data
* @param[in] bias_accessor (Optional) Accessor of the bias node data
* @param[in] weights_quant_info (Optional) Weights quantization info
@@ -116,7 +119,7 @@ public:
*/
static NodeID add_convolution_node(Graph &g, NodeParams params, NodeIdxPair input,
Size2D kernel_spatial_extend, unsigned int depth, PadStrideInfo conv_info,
- unsigned int num_groups = 1, ConvolutionMethod method = ConvolutionMethod::DEFAULT,
+ unsigned int num_groups = 1, ConvolutionMethod method = ConvolutionMethod::DEFAULT, FastMathHint fast_math_hint = FastMathHint::DISABLED,
ITensorAccessorUPtr weights_accessor = nullptr, ITensorAccessorUPtr bias_accessor = nullptr,
const QuantizationInfo weights_quant_info = QuantizationInfo(),
const QuantizationInfo out_quant_info = QuantizationInfo());
diff --git a/arm_compute/graph/TypePrinter.h b/arm_compute/graph/TypePrinter.h
index 0ecd57de9d..6babd3961d 100644
--- a/arm_compute/graph/TypePrinter.h
+++ b/arm_compute/graph/TypePrinter.h
@@ -295,6 +295,24 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ConvolutionMethod &m
return os;
}
+/** Formatted output of the FastMathHint type. */
+inline ::std::ostream &operator<<(::std::ostream &os, const FastMathHint &hint)
+{
+ switch(hint)
+ {
+ case FastMathHint::ENABLED:
+ os << "ENABLED";
+ break;
+ case FastMathHint::DISABLED:
+ os << "DISABLED";
+ break;
+ default:
+ ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
+ }
+
+ return os;
+}
+
/** Formatted output of the DepthwiseConvolutionMethod type. */
inline ::std::ostream &operator<<(::std::ostream &os, const DepthwiseConvolutionMethod &method)
{
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index b195ed7eda..a910610c7a 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -116,6 +116,13 @@ enum class DepthwiseConvolutionMethod
OPTIMIZED_3x3, /**< Optimized 3x3 direct depthwise convolution */
};
+/** Enable or disable fast math for Convolution layer */
+enum class FastMathHint
+{
+ ENABLED, /**< Fast math enabled for Convolution layer */
+ DISABLED, /**< Fast math disabled for Convolution layer */
+};
+
/** Supported nodes */
enum class NodeType
{
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index c203e8c885..db3f8ba4f9 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -83,6 +83,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
const PadStrideInfo conv_info = node.convolution_info();
const ConvolutionMethod conv_algorithm = node.convolution_method();
+ //const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED; // FIXME (COMPMID-1138): uncomment once NEON and GLES support fast_math
// Validate function
Status status{};
@@ -95,7 +96,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info);
break;
case ConvolutionMethod::WINOGRAD:
- status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/);
break;
case ConvolutionMethod::DEFAULT:
status = ConvolutionLayer::validate(input, weights, biases, output, conv_info);
@@ -107,7 +108,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
// If validation fails try the Default approach
if(!bool(status))
{
- status = ConvolutionLayer::validate(input, weights, biases, output, conv_info);
+ status = ConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/);
if(bool(status))
{
ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : "
diff --git a/arm_compute/graph/frontend/IStreamOperators.h b/arm_compute/graph/frontend/IStreamOperators.h
index 350d78fd1c..4d680f9a0e 100644
--- a/arm_compute/graph/frontend/IStreamOperators.h
+++ b/arm_compute/graph/frontend/IStreamOperators.h
@@ -96,6 +96,18 @@ inline IStream &operator<<(IStream &s, DepthwiseConvolutionMethod depthwise_conv
s.hints().depthwise_convolution_method_hint = depthwise_convolution_method_hint;
return s;
}
+/** Overloaded stream operator to provide a fast math hint to the graph
+ *
+ * @param[in, out] s Stream to provide the hint to
+ * @param[in] fast_math_hint Convolution method hint to be considered
+ *
+ * @return Updated stream
+ */
+inline IStream &operator<<(IStream &s, FastMathHint fast_math_hint)
+{
+ s.hints().fast_math_hint = fast_math_hint;
+ return s;
+}
} // namespace frontend
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 54cf515aa7..d122a7a967 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -197,7 +197,7 @@ public:
NodeParams common_params = { name(), s.hints().target_hint };
return GraphBuilder::add_convolution_node(s.graph(), common_params, input,
Size2D(_conv_width, _conv_height), _ofm, _conv_info, _num_groups,
- s.hints().convolution_method_hint,
+ s.hints().convolution_method_hint, s.hints().fast_math_hint,
std::move(_weights), std::move(_bias), std::move(_weights_quant_info), std::move(_out_quant_info));
}
diff --git a/arm_compute/graph/frontend/Types.h b/arm_compute/graph/frontend/Types.h
index 6cf7460900..47893613c7 100644
--- a/arm_compute/graph/frontend/Types.h
+++ b/arm_compute/graph/frontend/Types.h
@@ -45,6 +45,7 @@ using graph::PoolingLayerInfo;
using graph::PoolingType;
using graph::Target;
using graph::ConvolutionMethod;
+using graph::FastMathHint;
using graph::DepthwiseConvolutionMethod;
using graph::TensorDescriptor;
using graph::DimensionRoundingType;
@@ -63,6 +64,7 @@ struct StreamHints
Target target_hint = { Target::UNSPECIFIED }; /**< Target execution hint */
ConvolutionMethod convolution_method_hint = { ConvolutionMethod::DEFAULT }; /**< Convolution method hint */
DepthwiseConvolutionMethod depthwise_convolution_method_hint = { DepthwiseConvolutionMethod::DEFAULT }; /**< Depthwise Convolution method hint */
+ FastMathHint fast_math_hint = { FastMathHint::DISABLED }; /**< Fast math hint */
};
} // namespace frontend
} // namespace graph
diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h
index d1186a8eae..aca60283d7 100644
--- a/arm_compute/graph/nodes/ConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h
@@ -38,9 +38,11 @@ public:
*
* @param[in] info Convolution layer attributes
* @param[in] method (Optional) Convolution method to use
+ * @param[in] fast_math_hint (Optional) Fast math hint
* @param[in] out_quant_info (Optional) Output quantization info
*/
- ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT, QuantizationInfo out_quant_info = QuantizationInfo());
+ ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT, FastMathHint fast_math_hint = FastMathHint::DISABLED,
+ QuantizationInfo out_quant_info = QuantizationInfo());
/** Sets the convolution layer method to use
*
* @param[in] method Method to use for convolution
@@ -51,9 +53,19 @@ public:
* @note This is an indication on which convolution layer implementation to use,
* if it fails to be created the library's heuristic approach will be used
*
- * @return Convolution layer method do be used by the node
+ * @return Convolution layer method to be used by the node
*/
ConvolutionMethod convolution_method() const;
+ /** Sets the fast math fast hint
+ *
+ * @param[in] hint Hint to use for convolution
+ */
+ void set_fast_math_hint(FastMathHint hint);
+ /** Fast math hint accessor
+ *
+ * @return Fast math hint to be used by the node
+ */
+ FastMathHint fast_math_hint() const;
/** Convolution metadata accessor
*
* @return Convolution information
@@ -80,6 +92,7 @@ public:
private:
PadStrideInfo _info;
ConvolutionMethod _method;
+ FastMathHint _fast_math_hint;
QuantizationInfo _out_quant_info;
};
} // namespace graph