aboutsummaryrefslogtreecommitdiff
path: root/arm_compute
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute')
-rw-r--r--arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h2
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h8
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMM.h10
3 files changed, 14 insertions, 6 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
index dc84a40ca8..3755d943c5 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
@@ -65,7 +65,7 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, const float beta);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, float beta);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9543d989b8..30d3f9bb62 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -358,6 +358,14 @@ inline TensorShape compute_rnn_shape(const ITensorInfo *input, const unsigned in
return output_shape;
}
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+{
+ TensorShape tensor_shape{ input0.tensor_shape() };
+ tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1.dimension(0));
+ tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0.dimension(1));
+
+ return tensor_shape;
+}
} // namespace shape_calculator
} // namespace misc
} // namespace arm_compute
diff --git a/arm_compute/runtime/CL/functions/CLGEMM.h b/arm_compute/runtime/CL/functions/CLGEMM.h
index 60ff32c6fa..7a145cc183 100644
--- a/arm_compute/runtime/CL/functions/CLGEMM.h
+++ b/arm_compute/runtime/CL/functions/CLGEMM.h
@@ -85,10 +85,10 @@ public:
void configure(const ICLTensor *a, const ICLTensor *b, const ICLTensor *c, ICLTensor *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMM.
*
- * @param[in] a First input tensor (Matrix or Vector A). Data types supported: QS8/QS16/F16/F32
- * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a.
- * @param[in] c Third input tensor (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
- * @param[out] output Output tensor. Data type supported: same as @p a
+ * @param[in] a First input tensor info (Matrix or Vector A). Data types supported: QS8/QS16/F16/F32
+ * @param[in] b Second input tensor info (Matrix B). Data type supported: same as @p a.
+ * @param[in] c Third input tensor info (Matrix C). It can be a nullptr if just the multiplication between @p a and @p b is needed. Data type supported: same as @p a.
+ * @param[out] output Output tensor info. Data type supported: same as @p a
* @param[in] alpha Weight of the matrix product
* @param[in] beta Weight of matrix C
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
@@ -96,7 +96,7 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ICLTensor *c, const ITensorInfo *output, const float alpha, const float beta, const GEMMInfo &gemm_info = GEMMInfo());
+ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
// Inherited methods overridden:
void run() override;