author     Georgios Pinitas <georgios.pinitas@arm.com>  2019-06-21 18:43:12 +0100
committer  Georgios Pinitas <georgios.pinitas@arm.com>  2019-07-05 15:30:24 +0000
commit     37d080f2f11cfd734104b76512e1fb191486216e (patch)
tree       d5df067c826aacc0676e7e9557a54b61a9a3b7eb
parent     11de30da8a9f79943255ddba7bb70a66b076673b (diff)
download   ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  SConstruct  15
-rw-r--r--  arm_compute/core/Dimensions.h  23
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h  18
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h  55
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h  38
-rw-r--r--  arm_compute/core/Types.h  64
-rw-r--r--  arm_compute/core/WindowIterator.h  11
-rw-r--r--  arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h  46
-rw-r--r--  arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h  14
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp  4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp  4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp  4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp  6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp  6
-rw-r--r--  src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp  25
-rw-r--r--  src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h  37
-rw-r--r--  src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp  18
-rw-r--r--  src/runtime/NEON/functions/NEGEMM.cpp  10
-rw-r--r--  src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp  68
-rw-r--r--  src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp  2
-rw-r--r--  src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp  9
-rw-r--r--  src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp  12
27 files changed, 338 insertions, 181 deletions
diff --git a/SConstruct b/SConstruct
index f18b13782c..5fa6cdfc48 100644
--- a/SConstruct
+++ b/SConstruct
@@ -187,18 +187,15 @@ elif env['arch'] == 'arm64-v8a':
env.Append(CXXFLAGS = ['-no-integrated-as'])
elif 'arm64-v8.2-a' in env['arch']:
if env['arch'] == 'arm64-v8.2-a-sve':
- if env['os'] != 'bare_metal':
- print("Only bare metal SVE is supported at the moment")
- Exit(1)
env.Append(CXXFLAGS = ['-march=armv8.2-a+sve+fp16+dotprod'])
else:
env.Append(CXXFLAGS = ['-march=armv8.2-a+fp16']) # explicitly enable fp16 extension otherwise __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is undefined
- if env['os'] == 'linux':
- prefix = "aarch64-linux-gnu-"
- elif env['os'] == 'bare_metal':
- prefix = "aarch64-elf-"
- elif env['os'] == 'android':
- prefix = "aarch64-linux-android-"
+ if env['os'] == 'linux':
+ prefix = "aarch64-linux-gnu-"
+ elif env['os'] == 'bare_metal':
+ prefix = "aarch64-elf-"
+ elif env['os'] == 'android':
+ prefix = "aarch64-linux-android-"
env.Append(CPPDEFINES = ['ARM_COMPUTE_AARCH64_V8_2','NO_DOT_IN_TOOLCHAIN'])
if 'clang++' in cpp_compiler:
env.Append(CXXFLAGS = ['-no-integrated-as'])
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h
index 0a9264f6b0..9c38c60779 100644
--- a/arm_compute/core/Dimensions.h
+++ b/arm_compute/core/Dimensions.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -166,6 +166,27 @@ public:
collapse(num_dimensions() - start, start);
}
+ /** Remove dimension of a given index
+ *
+ * @note If index is greater than or equal to the number of dimensions, no operation is performed
+ *
+ * @param[in] idx Dimension index to remove
+ */
+ void remove(size_t idx)
+ {
+ ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
+ if(idx >= _num_dimensions)
+ {
+ return;
+ }
+
+ std::copy(_id.begin() + idx + 1, _id.end(), _id.begin() + idx);
+ _num_dimensions--;
+
+ // Make sure all empty dimensions are filled with 0
+ std::fill(_id.begin() + _num_dimensions, _id.end(), 0);
+ }
+
/** Returns a read/write iterator that points to the first element in the dimension array.
*
* @return an iterator.
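The new Dimensions::remove() shifts every higher dimension down by one slot and zero-fills the tail; indices at or beyond num_dimensions() are ignored. A minimal sketch of the behaviour (illustrative only, not part of the patch), using a plain Dimensions<int>:

#include "arm_compute/core/Dimensions.h"
#include <cassert>

int main()
{
    arm_compute::Dimensions<int> dims(4, 8, 2, 3); // x=4, y=8, z=2, w=3
    dims.remove(2);                                // drop the z entry
    assert(dims.num_dimensions() == 3);
    assert(dims[0] == 4 && dims[1] == 8 && dims[2] == 3);
    assert(dims[3] == 0);                          // freed slots are zero-filled
    return 0;
}

This is the primitive the rest of the patch builds on: a tensor re-interpreted as 3D can be walked as a 2D GEMM operand simply by removing the Z entry from its strides.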
diff --git a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
index 63178a738a..352f73d7f1 100644
--- a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
+++ b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,7 +45,7 @@ public:
unsigned int multis{ 0 }; /**< Number of "multi" GEMMs (unique A, B and C). */
};
- static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c);
+ static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info);
/** Constructor */
INEGEMMWrapperKernel();
@@ -61,13 +61,14 @@ public:
*
* @note The input and output tensor must have the same dimensions
*
- * @param[in] a Input tensor (Matrix A)
- * @param[in] b Input tensor (Matrix B)
- * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta);
+ void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
@@ -95,6 +96,7 @@ protected:
const ITensor *_b;
ITensor *_c;
Params _params;
+ GEMMInfo _gemm_info;
private:
Window _window3d;
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
index e2b849aa3d..40b6f5da39 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h
@@ -95,31 +95,32 @@ class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedM
public:
/** Configure the matrix multiplication: C = alpha * A * B + beta * C
*
- * @param[in] prepared_a Already reshaped matrix A.
- * @param[in] transformed_b Already reshaped matrix B.
- * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
- * @param[in,out] c Result matrix C.
- * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
- * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param[in] params M, N, K sizes.
- * @param[in] is_pretransposed Is B also pretransposed ?
- * @param[in] alpha Alpha value
- * @param[in] beta Beta value
- * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] prepared_a Already reshaped matrix A.
+ * @param[in] transformed_b Already reshaped matrix B.
+ * @param[out] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in,out] c Result matrix C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] max_num_threads Maximum number of threads that might be used for the calculations.
*/
void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads)
+ const INEGEMMWrapperKernel::Params &params, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads)
{
- _prepared_a = prepared_a;
- _transformed_b = transformed_b;
- _tmp_c = tmp_c;
- _c = c;
- _block_walker = block_walker;
- _block_sizes = block_sizes;
- _params = params;
- _b_is_pretransposed = b_is_pretransposed;
- _alpha = alpha;
- _beta = beta;
+ _prepared_a = prepared_a;
+ _transformed_b = transformed_b;
+ _tmp_c = tmp_c;
+ _c = c;
+ _block_walker = block_walker;
+ _block_sizes = block_sizes;
+ _params = params;
+ _b_is_pretransposed = gemm_info.pretranpose_B();
+ _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0;
+ _alpha = alpha;
+ _beta = beta;
auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads }));
}
@@ -133,6 +134,14 @@ public:
TensorAccessor<typename strategy::result_type> c(*_c);
TensorAccessor<typename strategy::result_type> tmp_c(*_tmp_c);
+ // Handle 3d output re-interpretation
+ if(_reinterpret_c_as_3d)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
+ }
+
int prev_batch = -1;
typename strategy::operand_type *a_ptr = nullptr;
auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
@@ -216,9 +225,9 @@ private:
INEGEMMWrapperKernel::Params _params{};
Window _block_walker{};
bool _b_is_pretransposed{ false };
+ bool _reinterpret_c_as_3d{ false };
typename strategy::result_type _alpha{};
typename strategy::result_type _beta{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
index 5d6cd02398..b18d327339 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
@@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans
public:
/** Configure the reshape A routine.
*
- * @param[in] a Input matrix A.
- * @param[out] transformed_a Reshaped matrix A.
- * @param[in] transpose_a Also transpose A ?
- * @param[in] block_walker Window representing the layout of the matrix's blocks
- * @param[in] params M, N, K sizes.
+ * @param[in] a Input matrix A.
+ * @param[out] transformed_a Reshaped matrix A.
+ * @param[in] transpose_a Also transpose A ?
+ * @param[in] reinterpret_a_as_3d Re-interpret as 3D ?
+ * @param[in] block_walker Window representing the layout of the matrix's blocks
+ * @param[in] params M, N, K sizes.
*/
- void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+ void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
{
- _a = a;
- _transformed_a = transformed_a;
- _transpose_a = transpose_a;
- _Ksize = params.K;
- _Msize = params.M;
- _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
+ _a = a;
+ _transformed_a = transformed_a;
+ _transpose_a = transpose_a;
+ _reinterpret_a_as_3d = reinterpret_a_as_3d;
+ _Ksize = params.K;
+ _Msize = params.M;
+ _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
}
// Inherited methods overridden:
@@ -110,12 +112,12 @@ public:
TensorAccessor<typename strategy::operand_type> a(*_a);
TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // Handle 3d input re-interpretation
+ if(_reinterpret_a_as_3d)
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize;
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
}
unsigned int last_m = 0;
@@ -164,8 +166,8 @@ private:
unsigned int _Msize{ 0 };
unsigned int _Ksize{ 0 };
bool _transpose_a{ false };
+ bool _reinterpret_a_as_3d{ false };
Window _k_multi_window{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index ad679d6786..b4d94eced4 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1765,9 +1765,17 @@ class GEMMInfo
{
public:
/** Default constructor */
- GEMMInfo()
- : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(),
- _fp_mixed_precision(false), _broadcast_bias(false)
+ GEMMInfo() noexcept
+ : _is_a_reshaped(false),
+ _is_b_reshaped(false),
+ _reshape_b_only_on_first_run(true),
+ _depth_output_gemm3d(0),
+ _reinterpret_input_as_3d(false),
+ _retain_internal_weights(false),
+ _gemmlowp_output_stage(),
+ _fp_mixed_precision(false),
+ _broadcast_bias(false),
+ _pretranpose_B(true)
{
}
/** Constructor
@@ -1785,10 +1793,17 @@ public:
* @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
*/
GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
- GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false)
- : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d),
- _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision),
- _broadcast_bias(broadcast_bias)
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) noexcept
+ : _is_a_reshaped(is_a_reshaped),
+ _is_b_reshaped(is_b_reshaped),
+ _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
+ _depth_output_gemm3d(depth_output_gemm3d),
+ _reinterpret_input_as_3d(reinterpret_input_as_3d),
+ _retain_internal_weights(retain_internal_weights),
+ _gemmlowp_output_stage(gemmlowp_output_stage),
+ _fp_mixed_precision(fp_mixed_precision),
+ _broadcast_bias(broadcast_bias),
+ _pretranpose_B(reshape_b_only_on_first_run)
{
}
/** Flag which specifies if the matrix A has been reshaped
@@ -1865,17 +1880,34 @@ public:
{
return _broadcast_bias;
};
+ /** Flag which specifies whether b should be pre-transposed if supported.
+ *
+ * @return True if b should be pre-transposed else false.
+ */
+ bool pretranpose_B() const
+ {
+ return _pretranpose_B;
+ };
+ /** Set pre-transpose b flag
+ *
+ * @param[in] flag Flag to set
+ */
+ void set_pretranpose_B(bool flag)
+ {
+ _pretranpose_B = flag;
+ }
private:
- const bool _is_a_reshaped;
- const bool _is_b_reshaped;
- const bool _reshape_b_only_on_first_run;
- const int _depth_output_gemm3d;
- const bool _reinterpret_input_as_3d;
- const bool _retain_internal_weights;
- const GEMMLowpOutputStageInfo _gemmlowp_output_stage;
- const bool _fp_mixed_precision;
- const bool _broadcast_bias;
+ bool _is_a_reshaped;
+ bool _is_b_reshaped;
+ bool _reshape_b_only_on_first_run;
+ int _depth_output_gemm3d;
+ bool _reinterpret_input_as_3d;
+ bool _retain_internal_weights;
+ GEMMLowpOutputStageInfo _gemmlowp_output_stage;
+ bool _fp_mixed_precision;
+ bool _broadcast_bias;
+ bool _pretranpose_B;
};
/** Winograd information */
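With GEMMInfo's members no longer const, a caller can copy an existing configuration and override just the pre-transpose hint, which is what NEGEMM::configure() does later in this patch when the memory policy is MINIMIZE. A small sketch (illustrative; the identifier spelling pretranpose_B follows the library):

#include "arm_compute/core/Types.h"

arm_compute::GEMMInfo without_pretranspose(const arm_compute::GEMMInfo &gemm_info)
{
    arm_compute::GEMMInfo info_ntb = gemm_info; // copying is possible now that the members are non-const
    info_ntb.set_pretranpose_B(false);          // keep B in its original layout
    return info_ntb;
}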
diff --git a/arm_compute/core/WindowIterator.h b/arm_compute/core/WindowIterator.h
index 32d6293a5a..15289b6d69 100644
--- a/arm_compute/core/WindowIterator.h
+++ b/arm_compute/core/WindowIterator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -86,6 +86,15 @@ public:
_strides[dim] = size;
}
+ /** Manually set the strides
+ *
+ * @param[in] strides Strides to set
+ */
+ void set_strides(const Strides &strides)
+ {
+ _strides = strides;
+ }
+
/** Returns a pointer to the element at coordinates (x,y,z,w)
*
* @param[in] x X coordinates
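The new set_strides() setter is what lets the assembly wrappers re-interpret a 3D tensor as a 2D GEMM operand: the Z stride is removed so consecutive depth slices are walked as extra rows. A hedged sketch of the pattern (assuming a TensorAccessor constructed from an ITensor, as the wrappers elsewhere in this patch do):

#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Strides.h"
#include "arm_compute/core/Window.h"
#include "arm_compute/core/WindowIterator.h"

void view_as_2d(const arm_compute::ITensor &tensor)
{
    arm_compute::TensorAccessor<float> acc(tensor);

    arm_compute::Strides strides_as_3d = tensor.info()->strides_in_bytes();
    strides_as_3d.remove(arm_compute::Window::DimZ); // collapse the depth stride
    acc.set_strides(strides_as_3d);                  // accessor now walks (W, H*D, batches)
}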
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
index 2fc2cf4a99..b5a2978ea1 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -64,17 +64,17 @@ private:
/** If supported create the ACL function corresponding to the GemmMethod provided to process the other passed parameters
*
- * @param[in] method GemmMethod to use to perform the matrix multiplication.
- * @param[in] a Input tensor (Matrix A).
- * @param[in] b Input tensor (Matrix B).
- * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
- * @param[in] pretransposed_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+ * @param[in] method GemmMethod to use to perform the matrix multiplication.
+ * @param[in] a Input tensor (Matrix A).
+ * @param[in] b Input tensor (Matrix B).
+ * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*
* @return True if the method is supported and the function was successfully created, false otherwise.
*/
- bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+ bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info);
/** Interface for the arm_gemm fallback */
std::unique_ptr<IFallback> _arm_gemm;
@@ -83,27 +83,27 @@ private:
public:
/** If supported create an ACL function else fallback to the arm_gemm function.
*
- * @param[in] a Input tensor (Matrix A)
- * @param[in] b Input tensor (Matrix B)
- * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
- * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint);
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info);
/** Indicates whether or not this function can be used to process the given parameters.
*
- * @param[in] a Input tensor (Matrix A)
- * @param[in] b Input tensor (Matrix B)
- * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
- * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)?
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input D matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*
* @return a status.
*/
- static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint);
+ static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info);
/** Was the function successfully configured ?
*
* @return True if the function is configured and ready to run
diff --git a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
index 949564750b..ad89e1fbec 100644
--- a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
+++ b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
@@ -104,14 +104,14 @@ public:
*
* @note The input and output tensor must have the same dimensions
*
- * @param[in] a Input tensor (Matrix A)
- * @param[in] b Input tensor (Matrix B)
- * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
- * @param[in] alpha Scalar multiplier to apply to AB matrix product.
- * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
- * @param[in] pretranspose_b If true, pretranspose B once during the prepare() stage instead of on the fly every time.
+ * @param[in] a Input tensor (Matrix A)
+ * @param[in] b Input tensor (Matrix B)
+ * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0.
+ * @param[in] alpha Scalar multiplier to apply to AB matrix product.
+ * @param[in] beta Scalar multiplier to apply to input C matrix before adding product.
+ * @param[in] gemm_info GEMM meta-data
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b);
+ void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info);
// Inherited methods overridden:
void run() override;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
index f4485bcbb1..e1af2d4490 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
@@ -61,12 +61,16 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo
switch ((y + 5) - ymax) {
case 4:
outptr1 = dummyres;
+ // fall through
case 3:
outptr2 = dummyres;
+ // fall through
case 2:
outptr3 = dummyres;
+ // fall through
case 1:
outptr4 = dummyres;
+ // fall through
case 0:
outptr5 = dummyres;
break;
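The merge and transform kernels below deliberately fall through each case so that every out-of-range row pointer is redirected to a dummy buffer; the "// fall through" comments added here only document that intent and silence -Wimplicit-fallthrough on newer compilers. A standalone sketch of the idiom (not library code):

#include <cstddef>

// Redirect the last `overrun` of four row pointers to a scratch buffer.
void redirect_overrun_rows(const float *(&rowptr)[4], const float *dummy, int overrun)
{
    switch(overrun)
    {
        case 3:
            rowptr[1] = dummy;
            // fall through
        case 2:
            rowptr[2] = dummy;
            // fall through
        case 1:
            rowptr[3] = dummy;
            break;
        default:
            break;
    }
}

On C++17 the [[fallthrough]]; attribute expresses the same intent, but a plain comment keeps the code buildable with pre-C++17 toolchains.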
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
index be23978b80..9fca4e3a84 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
@@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(float *out, const float *in, const int ld
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
index 9e5eb88dc1..0e638eef1c 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
@@ -66,16 +66,22 @@ inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, in
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
index 3ed43b10bd..60cc2f32da 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
@@ -65,16 +65,22 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout,
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
index 35d4cc5d73..0212dfdbb6 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
@@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const in
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index 20ad301a18..a460fdfcf4 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -60,12 +60,16 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *
/* Everything falls through in here */
case 4:
inptr1 = zerobuff;
+ // fall through
case 3:
inptr2 = zerobuff;
+ // fall through
case 2:
inptr3 = zerobuff;
+ // fall through
case 1:
inptr4 = zerobuff;
+ // fall through
case 0:
inptr5 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
index 2f513a6118..6a15fc42e4 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
@@ -57,8 +57,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
+ // fall through
case 1:
inptr2 = zerobuff;
+ // fall through
case 0:
inptr3 = zerobuff;
break;
@@ -93,8 +95,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
+ // fall through
case 1:
inptr2 = zerobuff;
+ // fall through
case 0:
inptr3 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 27136d144a..0028ab08a9 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -64,16 +64,22 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 54822c81b0..758c084a46 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index 0606330d27..de8e95a6d7 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
index 0fc3610014..d00f204b81 100644
--- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,11 +33,11 @@
using namespace arm_compute;
INEGEMMWrapperKernel::INEGEMMWrapperKernel()
- : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape()
+ : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape()
{
}
-INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c)
+INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info)
{
Params p;
@@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen
ARM_COMPUTE_ERROR_ON_NULLPTR(b);
ARM_COMPUTE_ERROR_ON_NULLPTR(c);
+ // Initialize params
p.M = c->info()->tensor_shape().y();
p.N = c->info()->tensor_shape().x();
p.K = a->info()->tensor_shape().x();
p.multis = b->info()->tensor_shape().z();
p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
+ // Update M in case of GEMM3D for output
+ if(gemm_info.depth_output_gemm3d() != 0)
+ {
+ p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z();
+ p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis;
+ }
+
return p;
}
-void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta)
+void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
{
- _params = extract_parameters(a, b, c);
- _a = a;
- _b = b;
- _c = c;
+ _gemm_info = gemm_info;
+ _params = extract_parameters(a, b, c, gemm_info);
+ _a = a;
+ _b = b;
+ _c = c;
_window3d = configure_internal(alpha, beta);
_window_shape = _window3d.shape();
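extract_parameters() now folds the GEMM3D depth back into M: when depth_output_gemm3d() is non-zero the output is laid out as (N, M', D, batches, multis), so the effective GEMM M becomes M' * D and the batch count is taken one dimension higher. A small worked example with hypothetical sizes (one multi, for brevity):

#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical output shape (x, y, z, w) = (N, M', D, batches)
    const std::size_t c_shape[4] = { 16, 8, 4, 2 };
    const bool output_is_3d = true; // i.e. gemm_info.depth_output_gemm3d() != 0

    std::size_t N       = c_shape[0];
    std::size_t M       = c_shape[1];              // 8 rows per depth slice
    std::size_t batches = c_shape[2] * c_shape[3]; // everything above dim 2

    if(output_is_3d)
    {
        M       = c_shape[1] * c_shape[2]; // 8 * 4 = 32 effective GEMM rows
        batches = c_shape[3];              // everything above dim 3
    }

    std::cout << "M=" << M << " N=" << N << " batches=" << batches << '\n';
    return 0;
}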
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
index 26d9e9999d..6e30148b5d 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -76,32 +76,34 @@ public:
* @param[in] transformed_a Reshaped tensor A.
* @param[in] block_walker Window representing the layout of the matrix's blocks.
* @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
*
* @return A wrapped specialized transformA kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) = 0;
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) = 0;
/** Instantiate and configure a prepareB Kernel
*
- * @param transformed_a Already reshaped tensor A.
- * @param transformed_b Already reshaped tensor B.
- * @param tmp_c Temporary buffer to be used to store intermediate results.
- * @param c Result tensor C.
- * @param block_walker Window containing iteration information for the M and batch dimensions.
- * @param block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param params M, N, K sizes.
- * @param alpha Alpha value
- * @param beta Beta value
- * @param pretranspose_b Is B also pretransposed ?
- * @param num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] transformed_a Already reshaped tensor A.
+ * @param[in] transformed_b Already reshaped tensor B.
+ * @param[in] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in] c Result tensor C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] num_threads Maximum number of threads that might be used for the calculations.
*
* @return A wrapped specialized MatrixMultiply kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) = 0;
/** Calculates the block sizes of a given strategy
*
@@ -138,19 +140,20 @@ public:
std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) override
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) override
{
auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<strategy>>();
- transform_a->configure(a, transformed_a, false, block_walker, params);
+ transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params);
return std::move(transform_a);
}
std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) override
{
auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<strategy>>();
- matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads);
+ matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads);
return std::move(matrix_multiply);
}
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
index 97c20dbd4e..ecdb5a938c 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
@@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const
TensorAccessor<To> b(*_b);
TensorAccessor<Tr> c(*_c);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // Handle 3d input re-interpretation
+ if(_gemm_info.reinterpret_input_as_3d())
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1);
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
+ }
+
+ // Handle 3d output re-interpretation
+ if(_gemm_info.depth_output_gemm3d() != 0)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
}
unsigned int m_end = 0;
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 55bcc45d12..2f36397c8e 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -58,17 +58,19 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe
_run_vector_matrix_multiplication = a->info()->dimension(1) < 2;
_original_b = b;
- bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run));
+ bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info));
if(run_optimised)
{
if(MEMInfo::get_policy() == MemoryPolicy::MINIMIZE)
{
- _asm_glue.configure(a, b, d, alpha, beta, false);
+ GEMMInfo gemm_info_ntb = gemm_info;
+ gemm_info_ntb.set_pretranpose_B(false);
+ _asm_glue.configure(a, b, d, alpha, beta, gemm_info_ntb);
}
else
{
- _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run);
+ _asm_glue.configure(a, b, d, alpha, beta, gemm_info);
}
ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured());
}
@@ -176,7 +178,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso
}
// Check if we need to run the optimized assembly kernel
- const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true));
+ const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, gemm_info));
if(!run_optimised)
{
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 55e067f52d..2de7d2b279 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -36,21 +36,22 @@ namespace arm_compute
namespace
{
std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
- const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ const ITensor *a, const ITensor *b, ITensor *d,
+ float alpha, float beta, const GEMMInfo &gemm_info,
std::shared_ptr<IMemoryManager> memory_manager)
{
- //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
+ // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
switch(gemm_kernel_info.method)
{
case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
{
- if(!pretranspose_hint)
+ if(!gemm_info.pretranpose_B())
{
return nullptr;
}
auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
- function->configure(a, b, d, alpha, beta, pretranspose_hint);
+ function->configure(a, b, d, alpha, beta, gemm_info);
return std::move(function);
}
#if defined(__aarch64__)
@@ -59,7 +60,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
{
auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
- kernel->configure(a, b, d, alpha, beta);
+ kernel->configure(a, b, d, alpha, beta, gemm_info);
auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
function->configure(std::move(kernel));
return std::move(function);
@@ -83,9 +84,11 @@ public:
* @param[in] b Input tensor containing the Matrix B.
* @param[out] d Output tensor to store the result of matrix multiplication.
* @param[in] args Matrix multiplication information.
+ * @param[in] gemm_info GEMM meta-data
* @param[in] memory_group Memory group to be used by the function.
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+ const GEMMInfo &gemm_info, MemoryGroup &memory_group);
// Inherited methods overridden:
void run() override;
@@ -123,10 +126,13 @@ private:
Tensor _pretranspose{};
/** Prepared flag */
bool _is_prepared{ false };
+ /** GEMM meta-data */
+ GEMMInfo _gemm_info{};
};
template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+ const GEMMInfo &gemm_info, MemoryGroup &memory_group)
{
arm_gemm::GemmConfig gemm_cfg;
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
@@ -168,6 +174,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
_a = a;
_b = b;
_d = d;
+ _gemm_info = gemm_info;
// Check for pre-transposed support
if(_gemm_kernel_asm->B_pretranspose_required())
{
@@ -222,17 +229,17 @@ void Fallback<TypeInput, TypeOutput>::run()
int ldb = 0;
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
- const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+ const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+ const size_t a_multi_idx = a_batch_idx + 1;
+ const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+ const size_t d_multi_idx = d_batch_idx + 1;
- const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
- const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+ const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+ const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
- const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
int multi_stride_b = 0;
- const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+ const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
@@ -270,24 +277,27 @@ void Fallback<TypeInput, TypeOutput>::run()
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
- ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
+void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function,
+ std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm,
+ MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+ ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
+ std::shared_ptr<IMemoryManager> memory_manager)
{
- INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+ INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
//Try to create an ACL function:
- acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));
+ acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
//If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, d, args, memory_group);
+ fallback->configure(a, b, d, args, gemm_info, memory_group);
arm_gemm = std::move(fallback);
}
}
@@ -299,11 +309,11 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> m
{
}
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
#ifndef __aarch64__
@@ -319,14 +329,14 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
return Status{};
}
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a);
ARM_COMPUTE_ERROR_ON_NULLPTR(b);
ARM_COMPUTE_ERROR_ON_NULLPTR(d);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info))
{
return;
}
@@ -334,20 +344,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
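The change in Fallback::run() above replaces the NHWC-based stride guess with an explicit index choice: when an operand is re-interpreted as 3D its batch dimension lives one index higher, so the batch and multi strides are read from indices 3/4 instead of 2/3. A sketch of just the index selection (the actual strides still come from ITensorInfo::strides_in_bytes()):

#include <cstddef>

struct StrideIndices
{
    std::size_t batch;
    std::size_t multi;
};

StrideIndices select_stride_indices(bool reinterpret_as_3d)
{
    const std::size_t batch_idx = reinterpret_as_3d ? 3 : 2; // 3D view pushes batches up by one
    return { batch_idx, batch_idx + 1 };                     // multi always sits above batch
}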
diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
index ede89bf558..5b70c8724c 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp
@@ -59,7 +59,7 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe
case DataType::QASYMM8:
case DataType::U8:
{
- _asm_glue.configure(a, b, output, 1.f, 0.f, true);
+ _asm_glue.configure(a, b, output, 1.f, 0.f, GEMMInfo(false, false, true));
run_optimised = _asm_glue.is_configured();
break;
}
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index d8773e37ab..f10f114287 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -87,7 +87,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
case DataType::U8:
case DataType::S8:
{
- _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, _reshape_b_only_on_first_run);
+ _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info);
_dot_product_path = _asm_glue.is_configured();
break;
}
@@ -224,9 +224,8 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
TensorInfo tmp_b_info{};
TensorInfo mm_result_s32_info{};
- int32_t a_offset = a->quantization_info().uniform().offset;
- int32_t b_offset = b->quantization_info().uniform().offset;
- const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run();
+ int32_t a_offset = a->quantization_info().uniform().offset;
+ int32_t b_offset = b->quantization_info().uniform().offset;
bool fuse_output_stage = gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if(fuse_output_stage)
@@ -235,7 +234,7 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
}
// Check if we need to run the optimized assembly kernel
- const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, reshape_b_only_on_first_run));
+ const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, gemm_info));
if(run_optimised)
{
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index 20aa1496b6..ac809fa142 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -339,19 +339,19 @@ void NEGEMMInterleavedWrapper::prepare()
}
}
-void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b)
+void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
{
- _params = INEGEMMWrapperKernel::extract_parameters(a, b, c);
+ _params = INEGEMMWrapperKernel::extract_parameters(a, b, c, gemm_info);
_a = a;
_b = b;
_c = c;
- _pretranspose_b = pretranspose_b;
+ _pretranspose_b = gemm_info.pretranpose_B();
const DataType input_type = a->info()->data_type();
const CPUInfo &ci = NEScheduler::get().cpu_info();
const unsigned int num_threads = NEScheduler::get().num_threads();
- const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, pretranspose_b);
+ const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, _pretranspose_b);
ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED);
// Forcing 128-byte alignment (required by 32-bit kernels)
@@ -411,8 +411,8 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
_memory_group.manage(&_transformed_a);
_memory_group.manage(&_tmp_c);
- _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params);
- _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, pretranspose_b, num_threads);
+ _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params, gemm_info);
+ _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, gemm_info, num_threads);
ARM_COMPUTE_ERROR_ON(_transform_a == nullptr);
ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr);