aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2019-06-26 17:18:11 +0100
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-06-28 13:51:34 +0000
commit7026b303d636e7639f8877ae8d5eff54f39c1121 (patch)
treed30d5969706dc688d84e276132c02cdd4c046e09 /arm_compute/core/utils
parent49f83497526816932e76e9e5f90a1799d50f15ba (diff)
downloadComputeLibrary-7026b303d636e7639f8877ae8d5eff54f39c1121.tar.gz
COMPMID-1979: Fuse Activation Function in CLGEMM - part 1
Implementing a new struct to contains the information for the OpenCL GEMM kernels Change-Id: I6c641c312f9c3b025a7c69dd0df3b730d2d2c2cb Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-on: https://review.mlplatform.org/c/1434 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h40
1 files changed, 40 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 7eab17ba11..010501454f 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/utils/helpers/tensor_transform.h"
@@ -851,6 +852,8 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
/** Calculate the matrix multiplication output shape of two tensors
*
+ * @note Deprecated. Remove when GEMMReshapeInfo is not used anymore by any other kernels
+ *
* @param[in] input0 First input tensor info
* @param[in] input1 Second input tensor info
* @param[in] gemm_info GEMM reshape info
@@ -888,6 +891,43 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
/** Calculate the matrix multiplication output shape of two tensors
*
+ * @param[in] input0 First input tensor info
+ * @param[in] input1 Second input tensor info
+ * @param[in] gemm_info GEMM kernel info used to retrieve the original dimensions of the input matrices
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, const GEMMKernelInfo &gemm_info)
+{
+ ARM_COMPUTE_ERROR_ON_MSG(input0.num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
+
+ const bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
+ const bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
+ const unsigned int depth_output_gemm3d = reinterpret_output_as_3d ? gemm_info.depth_output_gemm3d : 1;
+
+ TensorShape output_shape{ input0.tensor_shape() };
+
+ if(!reinterpret_input_as_3d && !reinterpret_output_as_3d)
+ {
+ output_shape.set(0, gemm_info.n);
+ output_shape.set(1, gemm_info.m);
+ }
+ else
+ {
+ // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
+ // dimension of the output tensor
+ const unsigned int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+ output_shape.set(0, gemm_info.n);
+ output_shape.set(1, gemm_info.m / depth_output_gemm3d);
+ output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
+ output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+ }
+
+ return output_shape;
+}
+
+/** Calculate the matrix multiplication output shape of two tensors
+ *
* @param[in] input Input tensor info
* @param[in] gemm_3d_depth (Optional) GEMM 3d depth
* @param[in] batch_size_on_z (Optional) True if batch size is on z axis