aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-21 18:43:12 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-07-05 15:30:24 +0000
commit37d080f2f11cfd734104b76512e1fb191486216e (patch)
treed5df067c826aacc0676e7e9557a54b61a9a3b7eb /arm_compute/core/NEON
parent11de30da8a9f79943255ddba7bb70a66b076673b (diff)
downloadComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1418 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/NEON')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h18
-rw-r--r--arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h55
-rw-r--r--arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h38
3 files changed, 62 insertions, 49 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
index 63178a738a..352f73d7f1 100644
--- a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
+++ b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@ public:
unsigned int multis{ 0 }; /**< Number of "multi" GEMMs (unique A, B and C). */
};
- static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c);
+ static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info);
/** Constructor */
INEGEMMWrapperKernel();
@@ -61,13 +61,14 @@ public:
*
* @note The input and output tensor must have the same dimensions
*
- * @param[in] a Input tensor (Matrix A)
- * @param[in] b Input tensor (Matrix B)
- * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta);
+ void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
@@ -95,6 +96,7 @@ protected:
const ITensor *_b;
ITensor *_c;
Params _params;
+ GEMMInfo _gemm_info;
private:
Window _window3d;
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
index e2b849aa3d..40b6f5da39 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
@@ -95,31 +95,32 @@ class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedM
public:
/** Configure the matrix multiplication: C = alpha * A * B + beta * C
*
- * @param[in] prepared_a Already reshaped matrix A.
- * @param[in] transformed_b Already reshaped matrix B.
- * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
- * @param[in,out] c Result matrix C.
- * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
- * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param[in] params M, N, K sizes.
- * @param[in] is_pretransposed Is B also pretransposed ?
- * @param[in] alpha Alpha value
- * @param[in] beta Beta value
- * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] prepared_a Already reshaped matrix A.
+ * @param[in] transformed_b Already reshaped matrix B.
+ * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in,out] c Result matrix C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
*/
void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
+ const INEGEMMWrapperKernel::Params &params, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads)
{
- _prepared_a = prepared_a;
- _transformed_b = transformed_b;
- _tmp_c = tmp_c;
- _c = c;
- _block_walker = block_walker;
- _block_sizes = block_sizes;
- _params = params;
- _b_is_pretransposed = b_is_pretransposed;
- _alpha = alpha;
- _beta = beta;
+ _prepared_a = prepared_a;
+ _transformed_b = transformed_b;
+ _tmp_c = tmp_c;
+ _c = c;
+ _block_walker = block_walker;
+ _block_sizes = block_sizes;
+ _params = params;
+ _b_is_pretransposed = gemm_info.pretranpose_B();
+ _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0;
+ _alpha = alpha;
+ _beta = beta;
auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads }));
}
@@ -133,6 +134,14 @@ public:
TensorAccessor<typename strategy::result_type> c(*_c);
TensorAccessor<typename strategy::result_type> tmp_c(*_tmp_c);
+ // Handle 3d output re-interpretation
+ if(_reinterpret_c_as_3d)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
+ }
+
int prev_batch = -1;
typename strategy::operand_type *a_ptr = nullptr;
auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
@@ -216,9 +225,9 @@ private:
INEGEMMWrapperKernel::Params _params{};
Window _block_walker{};
bool _b_is_pretransposed{ false };
+ bool _reinterpret_c_as_3d{ false };
typename strategy::result_type _alpha{};
typename strategy::result_type _beta{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
index 5d6cd02398..b18d327339 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
@@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans
public:
/** Configure the reshape A routine.
*
- * @param[in] a Input matrix A.
- * @param[out] transformed_a Reshaped matrix A.
- * @param[in] transpose_a Also transpose A ?
- * @param[in] block_walker Window representing the layout of the matrix's blocks
- * @param[in] params M, N, K sizes.
+ * @param[in] a Input matrix A.
+ * @param[out] transformed_a Reshaped matrix A.
+ * @param[in] transpose_a Also transpose A ?
+ * @param[in] reinterpret_a_as_3d Re-interpret as 3D ?
+ * @param[in] block_walker Window representing the layout of the matrix's blocks
+ * @param[in] params M, N, K sizes.
*/
- void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+ void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
{
- _a = a;
- _transformed_a = transformed_a;
- _transpose_a = transpose_a;
- _Ksize = params.K;
- _Msize = params.M;
- _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
+ _a = a;
+ _transformed_a = transformed_a;
+ _transpose_a = transpose_a;
+ _reinterpret_a_as_3d = reinterpret_a_as_3d;
+ _Ksize = params.K;
+ _Msize = params.M;
+ _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
}
// Inherited methods overridden:
@@ -110,12 +112,12 @@ public:
TensorAccessor<typename strategy::operand_type> a(*_a);
TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // Handle 3d input re-interpretation
+ if(_reinterpret_a_as_3d)
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize;
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
}
unsigned int last_m = 0;
@@ -164,8 +166,8 @@ private:
unsigned int _Msize{ 0 };
unsigned int _Ksize{ 0 };
bool _transpose_a{ false };
+ bool _reinterpret_a_as_3d{ false };
Window _k_multi_window{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */