author     Pablo Marquez Tello <pablo.tello@arm.com>   2021-12-08 15:56:01 +0000
committer  Pablo Marquez Tello <pablo.tello@arm.com>   2021-12-20 12:37:12 +0000
commit     4d44ac8685662984386b65869c3ed6af1144a419 (patch)
tree       535fecf108d6e257207d0742d688ee6ab4b9fed9
parent     d91761435df676720c93332a3fcbd428244c9843 (diff)
download   ComputeLibrary-4d44ac8685662984386b65869c3ed6af1144a419.tar.gz
Added support for filter size 8x8 NCHW DirectConv
* Allows NEDeconvLayer to reduce memory usage when the workload uses an 8x8 filter and the NCHW layout
* Resolves MLCE-696

Change-Id: Iaaf40c813376360f813d5babfb988d3e04e4bbc0
Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6806
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
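For reference, a minimal sketch (not part of this patch; the helper name and tensor shapes are illustrative) of user code that would exercise the new 8x8 NCHW F32 path through NEDirectConvolutionLayer with stride 1:

#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_8x8_nchw_direct_conv()
{
    Tensor src, weights, bias, dst;
    TensorInfo src_info(TensorShape(64U, 64U, 3U), 1, DataType::F32);    // W x H x C
    TensorInfo wei_info(TensorShape(8U, 8U, 3U, 4U), 1, DataType::F32);  // 8x8 filter, 3 IFMs, 4 OFMs
    TensorInfo bia_info(TensorShape(4U), 1, DataType::F32);
    TensorInfo dst_info(TensorShape(57U, 57U, 4U), 1, DataType::F32);    // (64 - 8) / 1 + 1 = 57
    src_info.set_data_layout(DataLayout::NCHW);
    wei_info.set_data_layout(DataLayout::NCHW);
    dst_info.set_data_layout(DataLayout::NCHW);

    src.allocator()->init(src_info);
    weights.allocator()->init(wei_info);
    bias.allocator()->init(bia_info);
    dst.allocator()->init(dst_info);

    NEDirectConvolutionLayer conv;
    conv.configure(&src, &weights, &bias, &dst, PadStrideInfo(1, 1, 0, 0)); // stride 1, no padding

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill src, weights and bias here ...
    conv.run();
}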
-rw-r--r--  src/cpu/kernels/CpuDirectConv2dKernel.cpp         288
-rw-r--r--  tests/validation/NEON/DirectConvolutionLayer.cpp   22
2 files changed, 298 insertions, 12 deletions
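As a reading aid for the kernel changes below, a scalar reference (illustrative only, assuming dense row-major NCHW buffers; the function name is hypothetical) of what the new convolver_8x8 path computes for one output element at stride 1 with no padding:

// out(oy, ox) = sum over c, ky, kx of in(c, oy + ky, ox + kx) * w(c, ky, kx)
float direct_conv_8x8_ref(const float *in, const float *w,
                          int in_w, int in_h, int channels,
                          int oy, int ox)
{
    float acc = 0.f;
    for(int c = 0; c < channels; ++c)     // plane 0 handled in "Step 1", remaining planes accumulated in "Step 2"
    {
        for(int ky = 0; ky < 8; ++ky)     // eight input rows per output row (in_0 .. in_7)
        {
            for(int kx = 0; kx < 8; ++kx) // eight taps per row, vectorised with vmlaq_f32/vextq_f32 in convolve_row
            {
                acc += in[(c * in_h + oy + ky) * in_w + (ox + kx)] * w[(c * 8 + ky) * 8 + kx];
            }
        }
    }
    return acc;
}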
diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.cpp b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
index 68de9803eb..1ab716aeac 100644
--- a/src/cpu/kernels/CpuDirectConv2dKernel.cpp
+++ b/src/cpu/kernels/CpuDirectConv2dKernel.cpp
@@ -329,8 +329,36 @@ public:
};
template <unsigned int stridex>
-float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
- const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
+float32x4_t convolve_8x8(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, const float *in_5, const float *in_6, const float *in_7,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, const float *m5, const float *m6, const float *m7);
+
+inline float32x4x4_t load_matrix4x4(const float *const m0, const float *const m1, const float *const m2, const float *const m3)
+{
+ const float32x4x4_t m00 =
+ {
+ {
+ vld1q_dup_f32(m0),
+ vld1q_dup_f32(m1),
+ vld1q_dup_f32(m2),
+ vld1q_dup_f32(m3)
+ }
+ };
+ return m00;
+}
+
+inline float32x4x3_t load_input(const float *const in)
+{
+ const float32x4x3_t vin =
+ {
+ {
+ vld1q_f32(in),
+ vld1q_f32(in + 4),
+ vld1q_f32(in + 8)
+ }
+ };
+ return vin;
+}
+
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
@@ -357,20 +385,79 @@ inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4
return m00;
}
-inline float32x4x3_t load_input(const float *const in)
+
+inline void convolve_row(float32x4_t &out, const float32x4x3_t &vin, const float32x4x4_t &lm, const float32x4x4_t &rm)
{
- const float32x4x3_t vin =
- {
- {
- vld1q_f32(in),
- vld1q_f32(in + 4),
- vld1q_f32(in + 8)
- }
- };
- return vin;
+ const auto &v0v1v2v3 = vin.val[0];
+ const auto &v4v5v6v7 = vin.val[1];
+ const auto &v8v9vavb = vin.val[2];
+ // |V0|V1|V2|V3| * |M0|M0|M0|M0|
+ out = vmlaq_f32(out, v0v1v2v3, lm.val[0]);
+ // |V1|V2|V3|V4| * |M1|M1|M1|M1|
+ out = vmlaq_f32(out, vextq_f32(v0v1v2v3, v4v5v6v7, 1), lm.val[1]);
+ // |V2|V3|V4|V5| * |M2|M2|M2|M2|
+ out = vmlaq_f32(out, vextq_f32(v0v1v2v3, v4v5v6v7, 2), lm.val[2]);
+ // |V3|V4|V5|V6| * |M3|M3|M3|M3|
+ out = vmlaq_f32(out, vextq_f32(v0v1v2v3, v4v5v6v7, 3), lm.val[3]);
+ // |V4|V5|V6|V7| * |M4|M4|M4|M4|
+ out = vmlaq_f32(out, v4v5v6v7, rm.val[0]);
+ // |V5|V6|V7|V8| * |M5|M5|M5|M5|
+ out = vmlaq_f32(out, vextq_f32(v4v5v6v7, v8v9vavb, 1), rm.val[1]);
+ // |V6|V7|V8|V9| * |M6|M6|M6|M6|
+ out = vmlaq_f32(out, vextq_f32(v4v5v6v7, v8v9vavb, 2), rm.val[2]);
+ // |V7|V8|V9|Va| * |M7|M7|M7|M7|
+ out = vmlaq_f32(out, vextq_f32(v4v5v6v7, v8v9vavb, 3), rm.val[3]);
}
template <>
+inline float32x4_t convolve_8x8<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4, const float *in_5, const float *in_6, const float *in_7,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, const float *m5, const float *m6, const float *m7)
+{
+ const float32x4x3_t vin0 = load_input(in_0); // bring 12 values from the first row
+ const float32x4x3_t vin1 = load_input(in_1); // bring 12 values from the second row
+ const float32x4x3_t vin2 = load_input(in_2);
+ const float32x4x3_t vin3 = load_input(in_3);
+ const float32x4x3_t vin4 = load_input(in_4);
+ const float32x4x3_t vin5 = load_input(in_5);
+ const float32x4x3_t vin6 = load_input(in_6);
+ const float32x4x3_t vin7 = load_input(in_7);
+
+ const float32x4x4_t m00 = load_matrix4x4(m0, 1 + m0, 2 + m0, 3 + m0);
+ const float32x4x4_t m01 = load_matrix4x4(4 + m0, 5 + m0, 6 + m0, 7 + m0);
+ const float32x4x4_t m10 = load_matrix4x4(m1, 1 + m1, 2 + m1, 3 + m1);
+ const float32x4x4_t m11 = load_matrix4x4(4 + m1, 5 + m1, 6 + m1, 7 + m1);
+ const float32x4x4_t m20 = load_matrix4x4(m2, 1 + m2, 2 + m2, 3 + m2);
+ const float32x4x4_t m21 = load_matrix4x4(4 + m2, 5 + m2, 6 + m2, 7 + m2);
+ const float32x4x4_t m30 = load_matrix4x4(m3, 1 + m3, 2 + m3, 3 + m3);
+ const float32x4x4_t m31 = load_matrix4x4(4 + m3, 5 + m3, 6 + m3, 7 + m3);
+ const float32x4x4_t m40 = load_matrix4x4(m4, 1 + m4, 2 + m4, 3 + m4);
+ const float32x4x4_t m41 = load_matrix4x4(4 + m4, 5 + m4, 6 + m4, 7 + m4);
+ const float32x4x4_t m50 = load_matrix4x4(m5, 1 + m5, 2 + m5, 3 + m5);
+ const float32x4x4_t m51 = load_matrix4x4(4 + m5, 5 + m5, 6 + m5, 7 + m5);
+ const float32x4x4_t m60 = load_matrix4x4(m6, 1 + m6, 2 + m6, 3 + m6);
+ const float32x4x4_t m61 = load_matrix4x4(4 + m6, 5 + m6, 6 + m6, 7 + m6);
+ const float32x4x4_t m70 = load_matrix4x4(m7, 1 + m7, 2 + m7, 3 + m7);
+ const float32x4x4_t m71 = load_matrix4x4(4 + m7, 5 + m7, 6 + m7, 7 + m7);
+
+ float32x4_t out = vdupq_n_f32(0.f);
+ convolve_row(out, vin0, m00, m01);
+ convolve_row(out, vin1, m10, m11);
+ convolve_row(out, vin2, m20, m21);
+ convolve_row(out, vin3, m30, m31);
+ convolve_row(out, vin4, m40, m41);
+ convolve_row(out, vin5, m50, m51);
+ convolve_row(out, vin6, m60, m61);
+ convolve_row(out, vin7, m70, m71);
+ return out;
+}
+
+
+template <unsigned int stridex>
+float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
+ const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
+
+
+template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
@@ -711,6 +798,130 @@ public:
}
};
+template <typename T1, typename T2, unsigned int stridex>
+class convolver_8x8
+{
+public:
+ static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
+ const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
+ {
+ ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
+ const int input_stride_x = src->info()->strides_in_bytes().x();
+ const int input_stride_y = src->info()->strides_in_bytes().y();
+ const int input_stride_z = src->info()->strides_in_bytes().z();
+ const int output_stride_y = dst->info()->strides_in_bytes().y();
+ const int output_stride_z = dst->info()->strides_in_bytes().z();
+ const int kernel_stride_x = weights->info()->strides_in_bytes().x();
+ const int kernel_stride_y = weights->info()->strides_in_bytes().y();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
+ const int output_w = dst->info()->dimension(0);
+ const int output_h = dst->info()->dimension(1);
+ const int num_planes_z = window.z().end() - window.z().start();
+ const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
+ const int kernel_depth = weights->info()->dimension(Window::DimZ);
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ const unsigned int conv_pad_top = conv_info.pad_top();
+
+ // setup output window for the iterator
+ Window window_out = window;
+ window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
+ window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
+ window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
+
+ // setup input window for the iterator
+ Window window_in = window;
+ // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
+ window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Window window_k = calculate_max_window(*weights->info(), Steps(1u));
+
+ Iterator out(dst, window_out);
+ Iterator in(src, window_in);
+ Iterator k(weights, window_k);
+
+ const uint8_t *k_ptr = k.ptr();
+
+ execute_window_loop(window_out, [&](const Coordinates & id)
+ {
+ const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
+ uint8_t *out_ptr = out.ptr();
+ int ih = 0;
+ int oh = 0;
+ for(int oz = 0; oz < num_planes_z; ++oz)
+ {
+ const int zoffset = id.z() + oz;
+ uint8_t *p_out_base = out_ptr + oz * output_stride_z;
+ // Step 1
+ {
+ const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r5 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 5 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r6 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 6 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r7 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 7 * kernel_stride_y + 0 * kernel_stride_x);
+ for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
+ {
+ auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
+ auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
+ auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
+ auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
+ auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
+ auto in_5 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 5) * input_stride_y);
+ auto in_6 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 6) * input_stride_y);
+ auto in_7 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 7) * input_stride_y);
+ auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
+ for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
+ in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, in_5 += delta_input, in_6 += delta_input, in_7 += delta_input,
+ p_out += num_elems_written_per_iteration)
+ {
+ auto vres = convolve_8x8<stridex>(in_0, in_1, in_2, in_3, in_4, in_5, in_6, in_7, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, ptr_k_r5, ptr_k_r6, ptr_k_r7);
+ vst1q_f32(p_out, vres);
+ }
+ }
+ }
+ // Step 2
+ for(int p = 1; p < kernel_depth; ++p)
+ {
+ const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r5 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 5 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r6 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 6 * kernel_stride_y + 0 * kernel_stride_x);
+ const auto ptr_k_r7 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 7 * kernel_stride_y + 0 * kernel_stride_x);
+ for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
+ {
+ auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
+ auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
+ auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
+ auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
+ auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
+ auto in_5 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 5) * input_stride_y);
+ auto in_6 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 6) * input_stride_y);
+ auto in_7 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 7) * input_stride_y);
+ auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
+ for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
+ in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, in_5 += delta_input, in_6 += delta_input, in_7 += delta_input,
+ p_out += num_elems_written_per_iteration)
+ {
+ auto vres = convolve_8x8<stridex>(in_0, in_1, in_2, in_3, in_4, in_5, in_6, in_7, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, ptr_k_r5, ptr_k_r6, ptr_k_r7);
+ vst1q_f32(p_out, vaddq_f32(vld1q_f32(p_out), vres));
+ }
+ }
+ }
+ }
+ },
+ in, out);
+ }
+};
+
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
@@ -815,6 +1026,22 @@ inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_i
}
}
+template <typename T1, typename T2>
+inline void convolve_8x8(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
+ const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
+{
+ const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
+ switch(conv_stride_x)
+ {
+ case 1:
+ convolver_8x8<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not implemented");
+ }
+}
+
+
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
@@ -834,7 +1061,12 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && src->data_type() != DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (src->data_type() == DataType::F16));
+ if(data_layout == DataLayout::NCHW && weights->dimension(width_idx) == 8u &&
+ weights->dimension(height_idx) == 8u && src->data_type() == DataType::F32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(std::get<0>(conv_info.stride()) != 1u);
+ }
// Checks performed when output is configured
if(dst->total_size() != 0)
{
@@ -932,6 +1164,24 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src, ITenso
}
}
break;
+ case 8:
+ {
+ switch(src->data_type())
+ {
+ case DataType::F32:
+ if(conv_stride_x > 1)
+ {
+ ARM_COMPUTE_ERROR("Stride > 1 not supported for kernel size 8 in NCHW.");
+ }
+ num_weight_elems_read_per_row = 4 + kernel_size - 1;
+ num_elems_read_per_iteration = 12;
+ num_elems_written_per_iteration = 4;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
+ }
+ }
+ break;
default:
{
ARM_COMPUTE_ERROR("Not implemented");
@@ -1336,6 +1586,20 @@ void CpuDirectConv2dKernel::run_op(ITensorPack &tensors, const Window &window, c
}
break;
}
+
+ case 8:
+ {
+ switch(src->info()->data_type())
+ {
+ case DataType::F32:
+ convolve_8x8<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ break;
+ }
default:
{
ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
diff --git a/tests/validation/NEON/DirectConvolutionLayer.cpp b/tests/validation/NEON/DirectConvolutionLayer.cpp
index 368aef216a..b6c2f0df1b 100644
--- a/tests/validation/NEON/DirectConvolutionLayer.cpp
+++ b/tests/validation/NEON/DirectConvolutionLayer.cpp
@@ -93,11 +93,23 @@ const auto data9x9 = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("PadY", { 0, 3 }),
framework::dataset::make("KernelSize", 9))))));
+
+const auto data8x8 = combine(datasets::SmallDirectConvolutionShapes(),
+ combine(framework::dataset::make("StrideX", { 1 }),
+ combine(framework::dataset::make("StrideY", { 1 }),
+ combine(framework::dataset::make("PadX", { 0 }),
+ combine(framework::dataset::make("PadY", { 0 }),
+ framework::dataset::make("KernelSize", 8))))));
+
+
+
const auto data_f32_nightly = combine(data_f32, framework::dataset::make("NumKernels", { 1, 4 }));
const auto data_f16_nightly = combine(data_f16, framework::dataset::make("NumKernels", { 1, 4 }));
const auto data_precommit = combine(data_prec, framework::dataset::make("NumKernels", { 1 }));
const auto data_precommit9x9 = combine(data9x9, framework::dataset::make("NumKernels", { 4 }));
+const auto data_precommit8x8 = combine(data8x8, framework::dataset::make("NumKernels", { 4 }));
+
/* The following tests is from real use-case that made DirectConvolution
* overflows in terms of its tensor indexing. This test case is using
@@ -326,6 +338,16 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEDirectConvolutionLayerMixedDataLayo
// Validate output
validate(Accessor(_target), _reference, tolerance_fp32);
}
+
+FIXTURE_DATA_TEST_CASE(RunSmall8x8, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(data_precommit8x8, framework::dataset::make("DataType",
+ DataType::F32)),
+ ActivationFunctionsDataset),
+ framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_fp32);
+}
+
FIXTURE_DATA_TEST_CASE(RunSmall9x9, NEDirectConvolutionLayerFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(data_precommit9x9, framework::dataset::make("DataType",
DataType::F32)),
ActivationFunctionsDataset),