aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/frontend/Layers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r--arm_compute/graph/frontend/Layers.h59
1 files changed, 58 insertions, 1 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index ec69350f86..2b44d0e844 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -902,6 +902,63 @@ private:
PoolingLayerInfo _pool_info;
};
+/** Print Layer */
+class PrintLayer final : public ILayer
+{
+public:
+ /** Construct a print layer.
+ *
+ * Example usage to locally dequantize and print a tensor:
+ *
+ * Tensor *output = new Tensor();
+ * const auto transform = [output](ITensor *input)
+ * {
+ * output->allocator()->init(*input->info());
+ * output->info()->set_data_type(DataType::F32);
+ * output->allocator()->allocate();
+ *
+ * Window win;
+ * win.use_tensor_dimensions(input->info()->tensor_shape());
+ * Iterator in(input, win);
+ * Iterator out(output, win);
+ * execute_window_loop(win, [&](const Coordinates &)
+ * {
+ * *(reinterpret_cast<float *>(out.ptr())) = dequantize_qasymm8(*in.ptr(), input->info()->quantization_info().uniform());
+ * }, in, out);
+ *
+ * return output;
+ * };
+ *
+ * graph << InputLayer(input_descriptor.set_quantization_info(in_quant_info), get_input_accessor(common_params, nullptr, false))
+ * << ...
+ * << \\ CNN Layers
+ * << ...
+ * << PrintLayer(std::cout, IOFormatInfo(), transform)
+ * << ...
+ * << OutputLayer(get_output_accessor(common_params, 5));
+ *
+ * @param[in] stream Output stream.
+ * @param[in] format_info (Optional) Format info.
+ * @param[in] transform (Optional) Input transform function.
+ */
+ PrintLayer(std::ostream &stream, const IOFormatInfo &format_info = IOFormatInfo(), const std::function<ITensor *(ITensor *)> transform = nullptr)
+ : _stream(stream), _format_info(format_info), _transform(transform)
+ {
+ }
+
+ NodeID create_layer(IStream &s) override
+ {
+ NodeParams common_params = { name(), s.hints().target_hint };
+ NodeIdxPair input = { s.tail_node(), 0 };
+ return GraphBuilder::add_print_node(s.graph(), common_params, input, _stream, _format_info, _transform);
+ }
+
+private:
+ std::ostream &_stream;
+ const IOFormatInfo &_format_info;
+ const std::function<ITensor *(ITensor *)> _transform;
+};
+
/** PriorBox Layer */
class PriorBoxLayer final : public ILayer
{