From e8f05da5fb919aa209e1bf0e5c70dd15fff84b7f Mon Sep 17 00:00:00 2001 From: thecha01 Date: Mon, 24 Aug 2020 17:21:41 +0100 Subject: Add ArgMinMax layer node to Graph API Change-Id: I2ccb2c65edd2932b76e905af3d747324b65c2f7f Signed-off-by: thecha01 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3910 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- arm_compute/graph/frontend/Layers.h | 42 +++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) (limited to 'arm_compute/graph/frontend/Layers.h') diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 2dd8b31ccf..da664e0448 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -145,6 +145,48 @@ private: const QuantizationInfo _out_quant_info; }; +/** ArgMinMax Layer */ +class ArgMinMaxLayer final : public ILayer +{ +public: + /** Construct an activation layer. + * + * @param[in] op Reduction Operation: min or max + * @param[in] axis Axis to perform reduction along + * @param[in] out_data_type (Optional) Output tensor data type + * @param[in] out_quant_info (Optional) Output quantization info + */ + ArgMinMaxLayer(ReductionOperation op, + unsigned int axis, + DataType out_data_type = DataType::UNKNOWN, + const QuantizationInfo out_quant_info = QuantizationInfo()) + : _op(op), + _axis(axis), + _out_data_type(out_data_type), + _out_quant_info(std::move(out_quant_info)) + { + } + + /** Create layer and add to the given stream. + * + * @param[in] s Stream to add layer to. + * + * @return ID of the created node. + */ + NodeID create_layer(IStream &s) override + { + NodeParams common_params = { name(), s.hints().target_hint }; + NodeIdxPair input = { s.tail_node(), 0 }; + return GraphBuilder::add_arg_min_max_node(s.graph(), common_params, input, _op, _axis, _out_data_type, std::move(_out_quant_info)); + } + +private: + ReductionOperation _op; + unsigned int _axis; + DataType _out_data_type; + QuantizationInfo _out_quant_info; +}; + /** Batchnormalization Layer */ class BatchNormalizationLayer final : public ILayer { -- cgit v1.2.1