aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h')
-rw-r--r-- arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h | 38
1 files changed, 20 insertions, 18 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
index 5d6cd02398..b18d327339 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
@@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans
public:
/** Configure the reshape A routine.
*
- * @param[in] a Input matrix A.
- * @param[out] transformed_a Reshaped matrix A.
- * @param[in] transpose_a Also transpose A ?
- * @param[in] block_walker Window representing the layout of the matrix's blocks
- * @param[in] params M, N, K sizes.
+ * @param[in] a Input matrix A.
+ * @param[out] transformed_a Reshaped matrix A.
+ * @param[in] transpose_a Whether to also transpose matrix A.
+ * @param[in] reinterpret_a_as_3d Whether to re-interpret input matrix A as a 3D tensor.
+ * @param[in] block_walker Window representing the layout of the matrix's blocks
+ * @param[in] params M, N, K sizes.
*/
- void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+ void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
{
- _a = a;
- _transformed_a = transformed_a;
- _transpose_a = transpose_a;
- _Ksize = params.K;
- _Msize = params.M;
- _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
+ _a = a;
+ _transformed_a = transformed_a;
+ _transpose_a = transpose_a;
+ _reinterpret_a_as_3d = reinterpret_a_as_3d;
+ _Ksize = params.K;
+ _Msize = params.M;
+ _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
}
// Inherited methods overridden:
@@ -110,12 +112,12 @@ public:
TensorAccessor<typename strategy::operand_type> a(*_a);
TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // When A is re-interpreted as 3D, access it with a copy of its strides from which the Z dimension has been removed
+ if(_reinterpret_a_as_3d)
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize;
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
}
unsigned int last_m = 0;
@@ -164,8 +166,8 @@ private:
unsigned int _Msize{ 0 };
unsigned int _Ksize{ 0 };
bool _transpose_a{ false };
+ bool _reinterpret_a_as_3d{ false };
Window _k_multi_window{};
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */