diff options
author | giuros01 <giuseppe.rossini@arm.com> | 2018-09-03 09:53:53 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:54:54 +0000 |
commit | efbf6c8fd54159b26eda43eea7a12fce491ca13a (patch) | |
tree | f24f63d73703ddcb5fe0ea3ccef101660a9eb9a4 /arm_compute/core/utils/misc/ShapeCalculator.h | |
parent | 477531c258801caf3cce44eb3e43df611b42fc6d (diff) | |
download | ComputeLibrary-efbf6c8fd54159b26eda43eea7a12fce491ca13a.tar.gz |
[COMPMID-386] Github: Support SoftmaxLayer on different number of dimensions?
Change-Id: I7422b977538ff29930a90f078badc2edee78af93
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/146638
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index d72547ed07..cb04182c21 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -275,6 +275,38 @@ inline TensorShape compute_flatten_shape(const ITensorInfo *input) return output_shape; } +inline TensorShape compute_softmax_shape(const ITensorInfo *input, size_t axis = 1) +{ + // The output shape will be a 2D version of the input. For instance: + // - [x,y,z] and axis 1 will return [x, y*z] + // - [x,y,z,w] and axis 2 will return [x*y, w*z] + // - [x,y,z,w] and axis 3 will return [x*y*z, w] + TensorShape shape2D = input->tensor_shape(); + + if(axis < input->num_dimensions()) + { + // Collapse from axis onward (this changes the shape) + shape2D.collapse_from(axis); + + // Collapse the rest (collapse is inclusive) + shape2D.collapse(shape2D.num_dimensions() - 1); + } + else + { + // Collapse everything + shape2D.collapse(shape2D.num_dimensions()); + } + + if(axis == 0) + { + // If axis is zero the first dim should be one. Since + // collapse is an inclusive operation we need to shift + shape2D.shift_right(1); + } + + return shape2D; +} + inline TensorShape compute_interleave_custom_shape(const TensorShape &input, const int x_interleave, const int y_interleave) { TensorShape output_shape{ input }; |