diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2017-12-22 15:27:52 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:16 +0000 |
commit | d8734b55d89f05901ba9a75349761a9c955d9243 (patch) | |
tree | e23d53a0fb73251f7416993e4d3a7241e533e79e /arm_compute/graph2/nodes | |
parent | 7390e05561a5c49306ebbf2eb2dcb1848546f201 (diff) | |
download | ComputeLibrary-d8734b55d89f05901ba9a75349761a9c955d9243.tar.gz |
COMPMID-793 : Add graph intermediate representation
Change-Id: Ic1685de4e19e0ac79669ef2da64e1dc96c7ea0bf
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/115248
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/graph2/nodes')
-rw-r--r-- | arm_compute/graph2/nodes/ActivationLayerNode.h | 59 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/BatchNormalizationLayerNode.h | 71 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/ConstNode.h | 54 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/ConvolutionLayerNode.h | 83 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/DepthConcatenateLayerNode.h | 77 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/DepthwiseConvolutionLayerNode.h | 83 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/EltwiseLayerNode.h | 59 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/FlattenLayerNode.h | 48 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/FullyConnectedLayerNode.h | 74 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/InputNode.h | 54 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/Nodes.h | 43 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/NodesFwd.h | 50 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/NormalizationLayerNode.h | 59 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/OutputNode.h | 48 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/PoolingLayerNode.h | 67 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/ReshapeLayerNode.h | 54 | ||||
-rw-r--r-- | arm_compute/graph2/nodes/SoftmaxLayerNode.h | 59 |
17 files changed, 1042 insertions, 0 deletions
diff --git a/arm_compute/graph2/nodes/ActivationLayerNode.h b/arm_compute/graph2/nodes/ActivationLayerNode.h new file mode 100644 index 0000000000..c3775231a4 --- /dev/null +++ b/arm_compute/graph2/nodes/ActivationLayerNode.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_ACTIVATION_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_ACTIVATION_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class ActivationLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] info Activation Layer information + */ + ActivationLayerNode(ActivationLayerInfo info); + /** Activation metadata accessor + * + * @return The activation info of the layer + */ + ActivationLayerInfo activation_info() const; + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + ActivationLayerInfo _info; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_ACTIVATION_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/BatchNormalizationLayerNode.h b/arm_compute/graph2/nodes/BatchNormalizationLayerNode.h new file mode 100644 index 0000000000..a521938414 --- /dev/null +++ b/arm_compute/graph2/nodes/BatchNormalizationLayerNode.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_BATCH_NORMALIZATION_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_BATCH_NORMALIZATION_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class BatchNormalizationLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] epsilon (Optional) Epsilon parameter. Defaults to 1.f + * @param[in] fused_activation (Optional) Fused activation layer. Disabled if not specified + */ + BatchNormalizationLayerNode(float epsilon = 1.f, ActivationLayerInfo fused_activation = ActivationLayerInfo()); + /** Epsilon parameter accessor + * + * @return Epsilon parameter + */ + float epsilon() const; + /** Returns fused activation + * + * @return Fused activation + */ + ActivationLayerInfo fused_activation() const; + /** Sets fused activation + * + * @param[in] fused_activation Fused activation to set + */ + void set_fused_activation(ActivationLayerInfo fused_activation); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + float _epsilon; + ActivationLayerInfo _fused_activation; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_BATCH_NORMALIZATION_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/ConstNode.h b/arm_compute/graph2/nodes/ConstNode.h new file mode 100644 index 0000000000..73a2246498 --- /dev/null +++ b/arm_compute/graph2/nodes/ConstNode.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_CONST_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_CONST_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class ConstNode final : public INode +{ +public: + /** Constructor + * + * @param[in] desc Tensor descriptor + */ + ConstNode(TensorDescriptor desc); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + TensorDescriptor _desc; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_CONST_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/ConvolutionLayerNode.h b/arm_compute/graph2/nodes/ConvolutionLayerNode.h new file mode 100644 index 0000000000..1af344ea13 --- /dev/null +++ b/arm_compute/graph2/nodes/ConvolutionLayerNode.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_CONVOLUTION_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_CONVOLUTION_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class ConvolutionLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] info Convolution layer attributes + * @param[in] method (Optional) Convolution method to use + */ + ConvolutionLayerNode(PadStrideInfo info, ConvolutionMethod method = ConvolutionMethod::DEFAULT); + /** Sets the convolution layer method to use + * + * @param[in] method Method to use for convolution + */ + void set_convolution_method(ConvolutionMethod method); + /** Convolution layer method accessor + * + * @note This is an indication on which convolution layer implementation to use, + * if it fails to be created the library's heuristic approach will be used + * + * @return Convolution layer method do be used by the node + */ + ConvolutionMethod convolution_method() const; + /** Convolution metadata accessor + * + * @return Convolution information + */ + PadStrideInfo convolution_info() const; + /** Computes convolution output shape + * + * @param[in] input_shape Input shape + * @param[in] weights_shape Weights shape + * @param[in] info Convolution operation attributes + * + * @return Output shape + */ + static TensorShape compute_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo info); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + PadStrideInfo _info; + ConvolutionMethod _method; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_CONVOLUTION_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/DepthConcatenateLayerNode.h b/arm_compute/graph2/nodes/DepthConcatenateLayerNode.h new file mode 100644 index 0000000000..617b9842fb --- /dev/null +++ b/arm_compute/graph2/nodes/DepthConcatenateLayerNode.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_DEPTH_CONCATENATE_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_DEPTH_CONCATENATE_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class DepthConcatenateLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] total_nodes Number of nodes that will get concatenated + */ + DepthConcatenateLayerNode(unsigned int total_nodes); + /** Computes depth concatenations output shape + * + * @param input_shapes Shapes of the inputs + * + * @return Expected output shape + */ + static TensorShape compute_output_shape(const std::vector<TensorShape> &input_shapes); + /** Disables or not the depth concatenate node + * + * @warning This is used when depth concatenate is performed with sub-tensors, + * where this node is used as a placeholder. + * + * @param[in] is_enabled If true a backend function is created to perform the depth concatenation (involves copying), + * while if false, no function is created and we assume that subtensors are properly set to simulate + * a no copy operation. + */ + void set_enabled(bool is_enabled); + /** Enabled parameter accessor + * + * @return True if a backend function is to be created else false + */ + bool is_enabled() const; + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + unsigned int _total_nodes; + bool _is_enabled; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_DEPTH_CONCATENATE_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph2/nodes/DepthwiseConvolutionLayerNode.h new file mode 100644 index 0000000000..1b05edf4dc --- /dev/null +++ b/arm_compute/graph2/nodes/DepthwiseConvolutionLayerNode.h @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_DEPTHWISE_CONVOLUTION_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_DEPTHWISE_CONVOLUTION_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class DepthwiseConvolutionLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] info Convolution layer attributes + * @param[in] method Depthwise convolution method to use + */ + DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::DEFAULT); + /** Sets the depthwise convolution method to use + * + * @param[in] method Depthwise convolution method to use + */ + void set_depthwise_convolution_method(DepthwiseConvolutionMethod method); + /** Depthwise convolution layer method accessor + * + * @note This is an indication on which depthwise implementation to use, + * if it fails to be created the generic approach will be used + * + * @return Depthwise convolution layer method do be used by the node + */ + DepthwiseConvolutionMethod depthwise_convolution_method() const; + /** Convolution metadata accessor + * + * @return Convolution information + */ + PadStrideInfo convolution_info() const; + /** Computes depthwise convolution output shape + * + * @param[in] input_shape Input shape + * @param[in] weights_shape Weights shape + * @param[in] info Convolution operation attributes + * + * @return Output shape + */ + static TensorShape compute_output_shape(TensorShape input_shape, TensorShape weights_shape, PadStrideInfo info); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + PadStrideInfo _info; + DepthwiseConvolutionMethod _method; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_DEPTHWISE_CONVOLUTION_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/EltwiseLayerNode.h b/arm_compute/graph2/nodes/EltwiseLayerNode.h new file mode 100644 index 0000000000..2b217decff --- /dev/null +++ b/arm_compute/graph2/nodes/EltwiseLayerNode.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_ELTWISE_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_ELTWISE_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class EltwiseLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] op Element-wise operation to perform + */ + EltwiseLayerNode(EltwiseOperation op); + /** Eltwise operation accessor + * + * @return Eltwise operation that is to be performed by the node + */ + EltwiseOperation eltwise_operation() const; + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + EltwiseOperation _op; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_ELTWISE_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/FlattenLayerNode.h b/arm_compute/graph2/nodes/FlattenLayerNode.h new file mode 100644 index 0000000000..de601f5f4e --- /dev/null +++ b/arm_compute/graph2/nodes/FlattenLayerNode.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_FLATTEN_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_FLATTEN_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class FlattenLayerNode final : public INode +{ +public: + /** Default Constructor */ + FlattenLayerNode(); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_FLATTEN_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/FullyConnectedLayerNode.h b/arm_compute/graph2/nodes/FullyConnectedLayerNode.h new file mode 100644 index 0000000000..836f20fdb3 --- /dev/null +++ b/arm_compute/graph2/nodes/FullyConnectedLayerNode.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_FULLY_CONNECTED_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_FULLY_CONNECTED_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class FullyConnectedLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] num_outputs Number of neurons in the layer + */ + FullyConnectedLayerNode(unsigned int num_outputs); + /** Computes weights shape + * + * @warning Works for inputs with 1D batch space + * + * @param[in] input_shape Input shape + * @param[in] num_outputs Number of output neurons + * + * @return Weights shape + */ + static TensorShape compute_weights_shape(TensorShape input_shape, unsigned int num_outputs); + /** Computes fully connected layer output shape + * + * @warning Works for inputs with 1D batch space + * + * @param[in] input_shape Input shape + * @param[in] num_outputs Number of output neurons + * + * @return Output shape + */ + static TensorShape compute_output_shape(TensorShape input_shape, unsigned int num_outputs); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + unsigned int _num_outputs; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_FULLY_CONNECTED_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/InputNode.h b/arm_compute/graph2/nodes/InputNode.h new file mode 100644 index 0000000000..2cad6f8fc6 --- /dev/null +++ b/arm_compute/graph2/nodes/InputNode.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_INPUT_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_INPUT_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class InputNode final : public INode +{ +public: + /** Constructor + * + * @param[in] desc Tensor descriptor + */ + InputNode(TensorDescriptor desc); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + TensorDescriptor _desc; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_INPUT_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/Nodes.h b/arm_compute/graph2/nodes/Nodes.h new file mode 100644 index 0000000000..8201361304 --- /dev/null +++ b/arm_compute/graph2/nodes/Nodes.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_NODES_H__ +#define __ARM_COMPUTE_GRAPH2_NODES_H__ + +#include "arm_compute/graph2/nodes/ActivationLayerNode.h" +#include "arm_compute/graph2/nodes/BatchNormalizationLayerNode.h" +#include "arm_compute/graph2/nodes/ConstNode.h" +#include "arm_compute/graph2/nodes/ConvolutionLayerNode.h" +#include "arm_compute/graph2/nodes/DepthConcatenateLayerNode.h" +#include "arm_compute/graph2/nodes/DepthwiseConvolutionLayerNode.h" +#include "arm_compute/graph2/nodes/EltwiseLayerNode.h" +#include "arm_compute/graph2/nodes/FlattenLayerNode.h" +#include "arm_compute/graph2/nodes/FullyConnectedLayerNode.h" +#include "arm_compute/graph2/nodes/InputNode.h" +#include "arm_compute/graph2/nodes/NormalizationLayerNode.h" +#include "arm_compute/graph2/nodes/OutputNode.h" +#include "arm_compute/graph2/nodes/PoolingLayerNode.h" +#include "arm_compute/graph2/nodes/ReshapeLayerNode.h" +#include "arm_compute/graph2/nodes/SoftmaxLayerNode.h" + +#endif /* __ARM_COMPUTE_GRAPH2_NODES_H__ */ diff --git a/arm_compute/graph2/nodes/NodesFwd.h b/arm_compute/graph2/nodes/NodesFwd.h new file mode 100644 index 0000000000..03ca65e056 --- /dev/null +++ b/arm_compute/graph2/nodes/NodesFwd.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_NODES_FWD_H__ +#define __ARM_COMPUTE_GRAPH2_NODES_FWD_H__ + +namespace arm_compute +{ +namespace graph2 +{ +// Forward declarations +class INode; +class ActivationLayerNode; +class BatchNormalizationLayerNode; +class ConstNode; +class ConvolutionLayerNode; +class DepthConcatenateLayerNode; +class DepthwiseConvolutionLayerNode; +class EltwiseLayerNode; +class FlattenLayerNode; +class FullyConnectedLayerNode; +class InputNode; +class NormalizationLayerNode; +class OutputNode; +class PoolingLayerNode; +class ReshapeLayerNode; +class SoftmaxLayerNode; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_NODES_FWD_H__ */ diff --git a/arm_compute/graph2/nodes/NormalizationLayerNode.h b/arm_compute/graph2/nodes/NormalizationLayerNode.h new file mode 100644 index 0000000000..e2816e9352 --- /dev/null +++ b/arm_compute/graph2/nodes/NormalizationLayerNode.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_NORMALIZATION_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_NORMALIZATION_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class NormalizationLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] norm_info Normalization Layer information + */ + NormalizationLayerNode(NormalizationLayerInfo norm_info); + /** Normalization info accessor + * + * @return Normalization layer info + */ + NormalizationLayerInfo normalization_info() const; + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + NormalizationLayerInfo _info; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_NORMALIZATION_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/OutputNode.h b/arm_compute/graph2/nodes/OutputNode.h new file mode 100644 index 0000000000..94df382d22 --- /dev/null +++ b/arm_compute/graph2/nodes/OutputNode.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_OUTPUT_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_OUTPUT_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class OutputNode final : public INode +{ +public: + /** Default Constructor */ + OutputNode(); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_OUTPUT_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/PoolingLayerNode.h b/arm_compute/graph2/nodes/PoolingLayerNode.h new file mode 100644 index 0000000000..b0c6270999 --- /dev/null +++ b/arm_compute/graph2/nodes/PoolingLayerNode.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_POOLING_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_POOLING_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class PoolingLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] pool_info Pooling Layer information + */ + PoolingLayerNode(PoolingLayerInfo pool_info); + /** Pooling metadata accessor + * + * @return Pooling Layer info + */ + PoolingLayerInfo pooling_info() const; + /** Computes pooling output shape + * + * @param[in] input_shape Input shape + * @param[in] info Pooling operation attributes + * + * @return Output shape + */ + static TensorShape compute_output_shape(TensorShape input_shape, PoolingLayerInfo info); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + PoolingLayerInfo _info; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_POOLING_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/ReshapeLayerNode.h b/arm_compute/graph2/nodes/ReshapeLayerNode.h new file mode 100644 index 0000000000..89ee46c8e1 --- /dev/null +++ b/arm_compute/graph2/nodes/ReshapeLayerNode.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_RESHAPE_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_RESHAPE_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class ReshapeLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] shape Reshaped tensor shape + */ + ReshapeLayerNode(TensorShape shape); + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + TensorShape _shape; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_RESHAPE_LAYER_NODE_H__ */ diff --git a/arm_compute/graph2/nodes/SoftmaxLayerNode.h b/arm_compute/graph2/nodes/SoftmaxLayerNode.h new file mode 100644 index 0000000000..86decb80d9 --- /dev/null +++ b/arm_compute/graph2/nodes/SoftmaxLayerNode.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2018 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_GRAPH2_SOFTMAX_LAYER_NODE_H__ +#define __ARM_COMPUTE_GRAPH2_SOFTMAX_LAYER_NODE_H__ + +#include "arm_compute/graph2/INode.h" + +namespace arm_compute +{ +namespace graph2 +{ +class SoftmaxLayerNode final : public INode +{ +public: + /** Constructor + * + * @param[in] beta (Optional) Beta parameter. Defaults to 1 + */ + SoftmaxLayerNode(float beta = 1.f); + /** Beta parameter accessor + * + * @return Beta parameter + */ + float beta() const; + + // Inherited overridden methods: + Status validate() override; + NodeType type() const override; + bool forward_descriptors() override; + TensorDescriptor configure_output(size_t idx) const override; + void accept(INodeVisitor &v) override; + +private: + float _beta; +}; +} // namespace graph2 +} // namespace arm_compute +#endif /* __ARM_COMPUTE_GRAPH2_SOFTMAX_LAYER_NODE_H__ */ |