aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-05-08 12:01:57 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:50 +0000
commit750641dd6aab1e5e62d1875b97b230312bb87959 (patch)
treeb3b180c07d7769cb32a6f35b6d0df2384a4638b0 /arm_compute/core
parentaa3240d3e2a575c436ec60ea0a31e8375d997425 (diff)
downloadComputeLibrary-750641dd6aab1e5e62d1875b97b230312bb87959.tar.gz
COMPMID-1052 - Rework validate method in CLGEMM
Change-Id: Iece5bd6478b5fac5164abff30c1e63e8a77291a9 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130374 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core')
-rw-r--r--arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h2
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h8
2 files changed, 9 insertions, 1 deletions
diff --git a/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h b/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
index dc84a40ca8..3755d943c5 100644
--- a/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMMatrixAdditionKernel.h
@@ -65,7 +65,7 @@ public:
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *output, const float beta);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *output, float beta);
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9543d989b8..30d3f9bb62 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -358,6 +358,14 @@ inline TensorShape compute_rnn_shape(const ITensorInfo *input, const unsigned in
return output_shape;
}
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, bool is_interleaved_transposed, const GEMMReshapeInfo &reshape_info)
+{
+ TensorShape tensor_shape{ input0.tensor_shape() };
+ tensor_shape.set(0, is_interleaved_transposed ? reshape_info.n() : input1.dimension(0));
+ tensor_shape.set(1, is_interleaved_transposed ? reshape_info.m() : input0.dimension(1));
+
+ return tensor_shape;
+}
} // namespace shape_calculator
} // namespace misc
} // namespace arm_compute