diff options
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r-- | arm_compute/graph/frontend/Layers.h | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h index 2dd8b31ccf..da664e0448 100644 --- a/arm_compute/graph/frontend/Layers.h +++ b/arm_compute/graph/frontend/Layers.h @@ -145,6 +145,48 @@ private: const QuantizationInfo _out_quant_info; }; +/** ArgMinMax Layer */ +class ArgMinMaxLayer final : public ILayer +{ +public: + /** Construct an activation layer. + * + * @param[in] op Reduction Operation: min or max + * @param[in] axis Axis to perform reduction along + * @param[in] out_data_type (Optional) Output tensor data type + * @param[in] out_quant_info (Optional) Output quantization info + */ + ArgMinMaxLayer(ReductionOperation op, + unsigned int axis, + DataType out_data_type = DataType::UNKNOWN, + const QuantizationInfo out_quant_info = QuantizationInfo()) + : _op(op), + _axis(axis), + _out_data_type(out_data_type), + _out_quant_info(std::move(out_quant_info)) + { + } + + /** Create layer and add to the given stream. + * + * @param[in] s Stream to add layer to. + * + * @return ID of the created node. + */ + NodeID create_layer(IStream &s) override + { + NodeParams common_params = { name(), s.hints().target_hint }; + NodeIdxPair input = { s.tail_node(), 0 }; + return GraphBuilder::add_arg_min_max_node(s.graph(), common_params, input, _op, _axis, _out_data_type, std::move(_out_quant_info)); + } + +private: + ReductionOperation _op; + unsigned int _axis; + DataType _out_data_type; + QuantizationInfo _out_quant_info; +}; + /** Batchnormalization Layer */ class BatchNormalizationLayer final : public ILayer { |