aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/TensorShape.h
diff options
context:
space:
mode:
authorDiego Lopez Recas <Diego.LopezRecas@arm.com>2017-12-18 14:42:56 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:45:00 +0000
commit0021d750d66d199c411df00cdd8308c325f1fef3 (patch)
treeb96e618977442a8aab335c136d369a958998d416 /arm_compute/core/TensorShape.h
parent5b6904b8d9cb5e8a343cde96fd5a8701f44dff90 (diff)
downloadComputeLibrary-0021d750d66d199c411df00cdd8308c325f1fef3.tar.gz
IVGCVSW-863 Broadcast support in CL/NEON Arithmetic Add
Also, added instrumentation to support generic tensor broadcasting for NEON and CL backends. Change-Id: I1bc5747a286e1a4b464c209067581e103d473b9a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/114201 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r--arm_compute/core/TensorShape.h58
1 files changed, 58 insertions, 0 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index 50f1211c18..dc836c98da 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/Dimensions.h"
#include "arm_compute/core/Error.h"
+#include "arm_compute/core/utils/misc/utility.h"
#include <algorithm>
#include <array>
@@ -132,6 +133,19 @@ public:
std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
}
+ /** Return a copy with collapsed dimensions starting from a given point.
+ *
+ * @param[in] start Starting point of collapsing dimensions.
+ *
+ * @return A copy with collapse dimensions starting from start.
+ */
+ TensorShape collapsed_from(size_t start) const
+ {
+ TensorShape copy(*this);
+ copy.collapse(num_dimensions(), start);
+ return copy;
+ }
+
/** Collapses all dimensions to a single linear total size.
*
* @return The total tensor size in terms of elements.
@@ -164,6 +178,50 @@ public:
return std::accumulate(_id.begin(), _id.begin() + dimension, 1, std::multiplies<size_t>());
}
+ /** If shapes are broadcast compatible, return the broadcasted shape.
+ *
+ * Two tensor shapes are broadcast compatible if for each dimension, they're equal or one of them is 1.
+ *
+ * If two shapes are compatible, each dimension in the broadcasted shape is the max of the original dimensions.
+ *
+ * @param[in] shapes Tensor shapes.
+ *
+ * @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible.
+ */
+ template <typename... Shapes>
+ static TensorShape broadcast_shape(const Shapes &... shapes)
+ {
+ TensorShape bc_shape;
+
+ auto broadcast = [&bc_shape](const TensorShape & other)
+ {
+ if(bc_shape.num_dimensions() == 0)
+ {
+ bc_shape = other;
+ }
+ else if(other.num_dimensions() != 0)
+ {
+ for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
+ {
+ const size_t dim_min = std::min(bc_shape[d], other[d]);
+ const size_t dim_max = std::max(bc_shape[d], other[d]);
+
+ if((dim_min != 1) && (dim_min != dim_max))
+ {
+ bc_shape = TensorShape{ 0U };
+ break;
+ }
+
+ bc_shape.set(d, dim_max);
+ }
+ }
+ };
+
+ utility::for_each(broadcast, shapes...);
+
+ return bc_shape;
+ }
+
private:
/** Remove trailing dimensions of size 1 from the reported number of dimensions. */
void apply_dimension_correction()