aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/ITensorInfo.h
diff options
context:
space:
mode:
authorDiego Lopez Recas <Diego.LopezRecas@arm.com>2017-12-18 14:42:56 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:45:00 +0000
commit0021d750d66d199c411df00cdd8308c325f1fef3 (patch)
treeb96e618977442a8aab335c136d369a958998d416 /arm_compute/core/ITensorInfo.h
parent5b6904b8d9cb5e8a343cde96fd5a8701f44dff90 (diff)
downloadComputeLibrary-0021d750d66d199c411df00cdd8308c325f1fef3.tar.gz
IVGCVSW-863 Broadcast support in CL/NEON Arithmetic Add
Also, added instrumentation to support generic tensor broadcasting for NEON and CL backends. Change-Id: I1bc5747a286e1a4b464c209067581e103d473b9a Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/114201 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/ITensorInfo.h')
-rw-r--r--arm_compute/core/ITensorInfo.h46
1 files changed, 46 insertions, 0 deletions
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 9112f3ea18..b5677dffd6 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -30,6 +30,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/misc/ICloneable.h"
+#include "arm_compute/core/utils/misc/utility.h"
#include <cstddef>
@@ -221,6 +222,51 @@ public:
* @return A QuantizationInfo containing the scale and offset.
*/
virtual QuantizationInfo quantization_info() const = 0;
+
+    /** Compute the broadcast shape of the given tensor infos together with the intersection of
+     * their broadcast valid regions.
+     *
+     * Tensor infos are broadcast compatible when, in every dimension, their extents are equal or one
+     * of them is 1. Each dimension of the broadcast shape is the maximum of the input extents.
+     *
+     * @param[in] infos Tensor infos to broadcast.
+     *
+     * @return The broadcast shape paired with the intersected valid region; both are empty when the
+     *         infos are not broadcast compatible.
+     */
+    template <typename... Infos>
+    static std::pair<TensorShape, ValidRegion> broadcast_shape_and_valid_region(const Infos &... infos)
+    {
+        TensorShape bc_shape = TensorShape::broadcast_shape(infos.tensor_shape()...);
+        ValidRegion bc_valid_region{ Coordinates(), bc_shape };
+
+        // Narrows bc_valid_region using the valid region of one tensor info at a time.
+        auto intersect_with = [&bc_valid_region](const ITensorInfo & info)
+        {
+            if(info.num_dimensions() == 0)
+            {
+                return;
+            }
+
+            for(size_t dim = 0; dim < bc_valid_region.shape.num_dimensions(); ++dim)
+            {
+                const bool   broadcast_dim = (info.tensor_shape()[dim] == 1);
+                const int    anchor_hi     = std::max(bc_valid_region.anchor[dim], info.valid_region().anchor[dim]);
+                const size_t extent_lo     = std::min(bc_valid_region.shape[dim], info.valid_region().shape[dim]);
+                // Dimensions the info broadcasts along are left untouched unless the common extent is zero.
+                if(broadcast_dim && (extent_lo != 0))
+                {
+                    continue;
+                }
+                bc_valid_region.anchor.set(dim, anchor_hi);
+                bc_valid_region.shape.set(dim, extent_lo);
+            }
+        };
+
+        utility::for_each(intersect_with, infos...);
+
+        return std::make_pair(bc_shape, bc_valid_region);
+    }
};
}
#endif /*__ARM_COMPUTE_TENSORINFO_H__ */