author     Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>  2023-06-19 14:57:57 +0100
committer  Mohmun02 <MohammedSuhail.Munshi@arm.com>  2023-06-26 11:34:03 +0000
commit     a2bb80ea7111509c24caad8629533089decef430 (patch)
tree       f674572e0cc705af9b66633bfcd9d6ad9e29d970
parent     c952596e70f2fe0073029f053e329a4e930ced8c (diff)
download   ComputeLibrary-a2bb80ea7111509c24caad8629533089decef430.tar.gz
Use MatMul in fully connected layer with dynamic weights when supported
- Use MatMul kernels in FC layer when using dynamic weights without broadcasting or bias.
- Fix minor typo in IClMatMulNativeKernelConfig.h

Partially Resolves: [COMPMID-6193]

Signed-off-by: Mohammed Suhail Munshi <MohammedSuhail.Munshi@arm.com>
Change-Id: Id494062b5b4f4e75ff9714c202dde941955afa52
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9797
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  src/core/CL/cl_kernels/common/mat_mul_quantized.cl                  |  10
-rw-r--r--  src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp                     |   7
-rw-r--r--  src/gpu/cl/operators/ClFullyConnected.cpp                           | 352
-rw-r--r--  src/gpu/cl/operators/ClFullyConnected.h                             |  19
-rw-r--r--  src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h  |   2
-rw-r--r--  tests/validation/CL/FullyConnectedLayer.cpp                         |  38
-rw-r--r--  tests/validation/fixtures/FullyConnectedLayerFixture.h              |  63
7 files changed, 330 insertions(+), 161 deletions(-)
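
For orientation, here is a minimal sketch (not part of the patch) of the call pattern that reaches the new MatMul path: non-constant (dynamic) weights, no bias, and a non-batched output, mirroring the gating added in ClFullyConnected::configure() below and the no-bias test fixture at the end of this diff. The function name run_dynamic_fc_no_bias, the shapes, and the data type are illustrative only.

// Hypothetical usage sketch: a fully connected layer with dynamic weights, no bias and a
// single batch, which is the case this change routes to the MatMul kernels.
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

void run_dynamic_fc_no_bias()
{
    CLScheduler::get().default_init();

    // src: 128 inputs, single batch; weights: [128, 256] (not pre-reshaped); dst: 256 outputs
    CLTensor src, weights, dst;
    src.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
    TensorInfo wei_info(TensorShape(128U, 256U), 1, DataType::F32);
    wei_info.set_are_values_constant(false); // dynamic weights -> MatMul path
    weights.allocator()->init(wei_info);
    dst.allocator()->init(TensorInfo(TensorShape(256U), 1, DataType::F32));

    FullyConnectedLayerInfo fc_info{};
    fc_info.transpose_weights = true; // weights are not pre-reshaped; MatMul relies on its adj_rhs flag

    CLFullyConnectedLayer fc;
    fc.configure(&src, &weights, nullptr /* no bias: required for the MatMul path */, &dst, fc_info);

    src.allocator()->allocate();
    weights.allocator()->allocate();
    dst.allocator()->allocate();

    // ... map and fill src and weights here (weights may change between runs) ...
    fc.run();
}
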
diff --git a/src/core/CL/cl_kernels/common/mat_mul_quantized.cl b/src/core/CL/cl_kernels/common/mat_mul_quantized.cl
index bd415bb4a7..8cf857dd84 100644
--- a/src/core/CL/cl_kernels/common/mat_mul_quantized.cl
+++ b/src/core/CL/cl_kernels/common/mat_mul_quantized.cl
@@ -21,9 +21,9 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+#include "activation_float_helpers.h"
#include "helpers.h"
#include "tile_helpers.h"
-#include "activation_float_helpers.h"
#if defined(MAT_MUL_NATIVE_QUANTIZED_NT_NT)
/** This OpenCL kernel performs the batch matrix multiplication (BatchMatMul): LHS non-transposed, RHS non-transposed - buffer only
@@ -189,7 +189,7 @@ __kernel void mat_mul_native_quantized_nt_nt(
{
LOOP_UNROLLING(int, j, 0, 1, N0,
{
- acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+ acc[i].s[j] -= ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
})
})
@@ -368,7 +368,7 @@ __kernel void mat_mul_native_quantized_nt_t(
{
LOOP_UNROLLING(int, j, 0, 1, N0,
{
- acc[i].s[j] += ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+ acc[i].s[j] -= ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
})
})
@@ -549,7 +549,7 @@ __kernel void mat_mul_native_quantized_t_nt(
{
LOOP_UNROLLING(int, j, 0, 1, N0,
{
- acc[i].s[j] += ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+ acc[i].s[j] -= ((int)(RHS_OFFSET)) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
})
})
@@ -734,7 +734,7 @@ __kernel void mat_mul_native_quantized_t_t(
{
LOOP_UNROLLING(int, j, 0, 1, N0,
{
- acc[i].s[j] += ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
+ acc[i].s[j] -= ((int)RHS_OFFSET) * a_sum[0].s[i] + ((int)(LHS_OFFSET)) * b_sum[0].s[j];
})
})
diff --git a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp
index 9bbec908a3..38d78c618b 100644
--- a/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulLowpNativeKernel.cpp
@@ -164,9 +164,10 @@ void ClMatMulLowpNativeKernel::configure(const ClCompileContext &compile_context
build_opts.add_option("-DDST_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
build_opts.add_option("-DDST_SHIFT=" + support::cpp11::to_string(output_shift));
- build_opts.add_option("-DLHS_OFFSET=" + support::cpp11::to_string(-lqinfo.offset)); // Note this is passed as negative to maintain similarity with CLDirectConv2D
- build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(-rqinfo.offset)); // Note this is passed as negative to maintain similarity with CLDirectConv2D
- build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two)
+ // Note: Offsets are not negated, unlike in the gemmlowp kernels
+ build_opts.add_option("-DLHS_OFFSET=" + support::cpp11::to_string(lqinfo.offset));
+ build_opts.add_option("-DRHS_OFFSET=" + support::cpp11::to_string(rqinfo.offset));
+ build_opts.add_option("-DDST_OFFSET=" + support::cpp11::to_string(dqinfo.offset)); // Passed as positive (unlike the above two)
build_opts.add_option(("-DA_VAL=" + float_to_string_with_full_precision(act_info.a())));
build_opts.add_option(("-DB_VAL=" + float_to_string_with_full_precision(act_info.b())));
diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp
index b289cc0104..c62e4b531f 100644
--- a/src/gpu/cl/operators/ClFullyConnected.cpp
+++ b/src/gpu/cl/operators/ClFullyConnected.cpp
@@ -38,6 +38,12 @@
#include "src/gpu/cl/operators/ClTranspose.h"
#include "src/gpu/cl/utils/ClAuxTensorHandler.h"
+#include "src/gpu/cl/operators/ClMatMul.h"
+#include "utils/TypePrinter.h"
+
+#include "src/runtime/heuristics/matmul_native/ClMatMulNativeKernelConfig.h"
+#include "src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h"
+
#include "src/common/utils/Log.h"
#include "support/Cast.h"
@@ -52,6 +58,12 @@ using namespace arm_compute::misc::shape_calculator;
namespace
{
+// Function to calculate batched tensor shape in format [M, 1, B0, B1 ..] which is the format matmul expects
+inline TensorShape get_reshaped_matmul_tensor(const TensorShape &src)
+{
+ return TensorShape(src.x(), 1, src.y(), src.collapsed_from(2).z()); // Return value optimisation
+}
+
Status construct_gemmlowp_output_stage(const ITensorInfo &src, const ITensorInfo &weights, const ITensorInfo &dst,
GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info)
{
@@ -101,41 +113,61 @@ Status construct_gemmlowp_output_stage(const ITensorInfo &src, const ITensorInfo
Status validate_mm(const ITensorInfo &src, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &dst, const FullyConnectedLayerInfo &fc_info)
{
- GEMMLowpOutputStageInfo gemmlowp_output_stage;
- ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(src, weights, dst, gemmlowp_output_stage, fc_info.activation_info));
-
- const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
- false, // is_b_reshaped
- true, // reshape_b_only_on_first_run
- 0, // depth_output_gemm3d
- false, // reinterpret_input_as_3d
- fc_info.retain_internal_weights, // retain_internal_weights
- gemmlowp_output_stage, // gemmlowp_output_stage
- fc_info.fp_mixed_precision, // fp_mixed_precision
- false, // fast_math
- true, // broadcast_bias
- ActivationLayerInfo()); // activation_info
-
- if(is_data_type_quantized_asymmetric(src.data_type()))
+ // If weights are dynamic, data is not batched, and no bias is provided, validate using matmul.
+ const bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+ const bool use_matmul = !weights.are_values_constant() && !weights_reshaped && !(dst.dimension(1) > 1) && (bias == nullptr);
+
+ if(use_matmul)
{
- const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
- const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
-
- // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
- // Extract and negate src and weights offset
- const QuantizationInfo src_quantization_info(iq_info.scale, -iq_info.offset);
- const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);
-
- // Validate gemmlowp function
- ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixMultiplyCore::validate(&src.clone()->set_quantization_info(src_quantization_info),
- &weights.clone()->set_quantization_info(weights_quantization_info),
- bias,
- &dst,
- gemm_info));
+ MatMulInfo m_info{};
+ m_info.adj_rhs(fc_info.transpose_weights);
+
+ // Note: Currently, shape is [M, B0, B1]
+ // LHS is reshaped here to match ClMatMul's expected batch layout: [M, 1, B0, B1, ...]
+ TensorInfo lhs_to_use{ src };
+ lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(src.tensor_shape()));
+
+ // Operator level validation.
+ ARM_COMPUTE_RETURN_ON_ERROR(ClMatMul::validate(&lhs_to_use, &weights, &dst, m_info, fc_info.activation_info));
}
else
{
- ARM_COMPUTE_RETURN_ON_ERROR(ClGemm::validate(&src, &weights, bias, &dst, 1.f, 1.f, gemm_info));
+ GEMMLowpOutputStageInfo gemmlowp_output_stage;
+ ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(src, weights, dst, gemmlowp_output_stage, fc_info.activation_info));
+
+ const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
+ false, // is_b_reshaped
+ true, // reshape_b_only_on_first_run
+ 0, // depth_output_gemm3d
+ false, // reinterpret_input_as_3d
+ fc_info.retain_internal_weights, // retain_internal_weights
+ gemmlowp_output_stage, // gemmlowp_output_stage
+ fc_info.fp_mixed_precision, // fp_mixed_precision
+ false, // fast_math
+ true, // broadcast_bias
+ ActivationLayerInfo()); // activation_info
+
+ if(is_data_type_quantized_asymmetric(src.data_type()))
+ {
+ const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
+ const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
+
+ // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+ // Extract and negate src and weights offset
+ const QuantizationInfo src_quantization_info(iq_info.scale, -iq_info.offset);
+ const QuantizationInfo weights_quantization_info(wq_info.scale, -wq_info.offset);
+
+ // Validate gemmlowp function
+ ARM_COMPUTE_RETURN_ON_ERROR(ClGemmLowpMatrixMultiplyCore::validate(&src.clone()->set_quantization_info(src_quantization_info),
+ &weights.clone()->set_quantization_info(weights_quantization_info),
+ bias,
+ &dst,
+ gemm_info));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(ClGemm::validate(&src, &weights, bias, &dst, 1.f, 1.f, gemm_info));
+ }
}
return Status{};
@@ -148,6 +180,8 @@ ClFullyConnected::ClFullyConnected()
_reshape_weights(nullptr),
_mm_gemm(nullptr),
_mm_gemmlowp(nullptr),
+ _matmul_native_kernel(nullptr),
+ _matmul_lowp_native_kernel(nullptr),
_aux_mem(Count)
{
}
@@ -157,50 +191,85 @@ ClFullyConnected::~ClFullyConnected() = default;
void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
const FullyConnectedLayerInfo &fc_info)
{
- GEMMLowpOutputStageInfo gemmlowp_output_stage;
- construct_gemmlowp_output_stage(*src, *weights, *dst, gemmlowp_output_stage, fc_info.activation_info);
-
- const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
- false, // is_b_reshaped
- !_dynamic_weights, // reshape_b_only_on_first_run
- 0, // depth_output_gemm3d
- false, // reinterpret_input_as_3d
- fc_info.retain_internal_weights, // retain_internal_weights
- gemmlowp_output_stage, // gemmlowp_output_stage
- fc_info.fp_mixed_precision, // fp_mixed_precision
- false, // fast_math
- true, // broadcast_bias
- fc_info.activation_info); // activation_info
-
- if(_is_quantized)
+ // If weights are dynamic, configure matmul operator - else use gemm
+ if(_use_matmul)
{
- // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
- // Extract and negate input and weights offset
- const QuantizationInfo src_quantization_info = src->quantization_info();
- const QuantizationInfo weights_quantization_info = weights->quantization_info();
-
- TensorInfo src_info = src->clone()->set_quantization_info(src_quantization_info);
- TensorInfo weights_info = weights->clone()->set_quantization_info(weights_quantization_info);
-
- src_info.set_quantization_info(QuantizationInfo(src_quantization_info.uniform().scale, -src_quantization_info.uniform().offset));
- weights_info.set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));
-
- // Configure gemmlowp function
- _mm_gemmlowp = std::make_unique<ClGemmLowpMatrixMultiplyCore>();
- _mm_gemmlowp->configure(compile_context, &src_info, &weights_info, bias, dst, gemm_info);
+ // Transpose RHS as _are_weights_reshaped == false when mat_mul is used.
+ const MatMulInfo mat_info = MatMulInfo().adj_rhs(fc_info.transpose_weights);
+
+ // Note: MatMul does not need offset negation unlike gemm
+ // 1. Change shape when calling matmul to fit batch expectations.
+ _lhs_to_use = *src->clone();
+ _lhs_to_use.set_tensor_shape(get_reshaped_matmul_tensor(_lhs_to_use.tensor_shape())); // Collapse all dims > 2 into final dimension.
+ _is_quantized = is_data_type_quantized_asymmetric(_lhs_to_use.data_type());
+
+ // 2. Call kernel for matmul directly.
+ const GPUTarget gpu_target = CLScheduler::get().target();
+ std::unique_ptr<cl_matmul::IClMatMulNativeKernelConfig> kernel_config = cl_matmul::ClMatMulNativeKernelConfigurationFactory::create(gpu_target);
+
+ // Configure relevant matmul kernel
+ MatMulKernelInfo kernel_info = kernel_config->configure(src, weights, mat_info);
+ if(_is_quantized)
+ {
+ _matmul_lowp_native_kernel = std::make_unique<kernels::ClMatMulLowpNativeKernel>();
+ _matmul_lowp_native_kernel->set_target(gpu_target);
+ _matmul_lowp_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info);
+ }
+ else
+ {
+ _matmul_native_kernel = std::make_unique<kernels::ClMatMulNativeKernel>();
+ _matmul_native_kernel->set_target(gpu_target);
+ _matmul_native_kernel->configure(compile_context, src, weights, dst, kernel_info, fc_info.activation_info);
+ }
}
else
{
- // Configure matrix multiply kernel
- _mm_gemm = std::make_unique<ClGemm>();
- _mm_gemm->configure(compile_context, src, weights, bias, dst, 1.f, 1.f, gemm_info);
+ // Configure GEMM
+ GEMMLowpOutputStageInfo gemmlowp_output_stage;
+ construct_gemmlowp_output_stage(*src, *weights, *dst, gemmlowp_output_stage, fc_info.activation_info);
+
+ const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
+ false, // is_b_reshaped
+ !_dynamic_weights, // reshape_b_only_on_first_run
+ 0, // depth_output_gemm3d
+ false, // reinterpret_input_as_3d
+ fc_info.retain_internal_weights, // retain_internal_weights
+ gemmlowp_output_stage, // gemmlowp_output_stage
+ fc_info.fp_mixed_precision, // fp_mixed_precision
+ false, // fast_math
+ true, // broadcast_bias
+ fc_info.activation_info); // activation_info
+
+ if(_is_quantized)
+ {
+ // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
+ // Extract and negate input and weights offset
+ const QuantizationInfo src_quantization_info = src->quantization_info();
+ const QuantizationInfo weights_quantization_info = weights->quantization_info();
+
+ TensorInfo src_info = src->clone()->set_quantization_info(src_quantization_info);
+ TensorInfo weights_info = weights->clone()->set_quantization_info(weights_quantization_info);
+
+ src_info.set_quantization_info(QuantizationInfo(src_quantization_info.uniform().scale, -src_quantization_info.uniform().offset));
+ weights_info.set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));
+
+ // Configure gemmlowp function
+ _mm_gemmlowp = std::make_unique<ClGemmLowpMatrixMultiplyCore>();
+ _mm_gemmlowp->configure(compile_context, &src_info, &weights_info, bias, dst, gemm_info);
+ }
+ else
+ {
+ // Configure matrix multiply kernel
+ _mm_gemm = std::make_unique<ClGemm>();
+ _mm_gemm->configure(compile_context, src, weights, bias, dst, 1.f, 1.f, gemm_info);
+ }
}
}
void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
const FullyConnectedLayerInfo &fc_info)
{
- ARM_COMPUTE_ERROR_ON((weights->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
+ ARM_COMPUTE_ERROR_ON((weights->dimension((_use_matmul) ? 0 : 1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
// If the fully connected layer is called after a convolution layer, the input tensor must be linearized
@@ -211,6 +280,7 @@ void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context
_flatten = std::make_unique<ClFlatten>();
_flatten->configure(compile_context, src, &_flattened_src);
+ // Note: if the flattened tensor has more than one dimension, the remaining dimensions are batch dimensions
// Configure matrix multiply kernel
configure_mm(compile_context, &_flattened_src, weights, bias, dst, fc_info);
}
@@ -218,7 +288,8 @@ void ClFullyConnected::configure_conv_fc(const CLCompileContext &compile_context
void ClFullyConnected::configure_fc_fc(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *bias, ITensorInfo *dst,
const FullyConnectedLayerInfo &fc_info)
{
- ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension(1));
+ // Compare the first dimension when using matmul, as matmul handles the transpose itself
+ ARM_COMPUTE_ERROR_ON(src->dimension(0) != weights->dimension((_use_matmul) ? 0 : 1));
// Configure matrix multiply kernel
configure_mm(compile_context, src, weights, bias, dst, fc_info);
@@ -240,7 +311,13 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso
_is_prepared = fc_info.retain_internal_weights;
_weights_to_use = TensorInfo(*weights);
_weights_to_use_idx = ACL_SRC_1;
- _dynamic_weights = !weights->are_values_constant() && !_are_weights_reshaped;
+
+ // When using dynamic weights - use matmul kernels.
+ // Note: Dynamic weights with a pre-reshaped RHS are not supported.
+ // Note: No matmul with biases for the moment.
+ const bool is_batched_fc_layer = dst->dimension(1) > 1;
+ _dynamic_weights = !weights->are_values_constant() && !_are_weights_reshaped;
+ _use_matmul = _dynamic_weights && !is_batched_fc_layer && !(biases);
// With the Fully Connected layer we can have 4 different cases:
// 1) Convolution layer -> Fully Connected layer without batches
@@ -249,7 +326,6 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso
// 4) Fully Connected layer -> Fully Connected layer with batches
// Check if we have a fully connected layer with batches
- const bool is_batched_fc_layer = dst->dimension(1) > 1;
if(is_batched_fc_layer)
{
_is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
@@ -264,7 +340,8 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso
ITensorInfo *weights_used = weights;
// Reshape weights if needed
- if(!_are_weights_reshaped)
+ // Not needed when matmul is in use - MatMul has transpose RHS flags.
+ if(!_are_weights_reshaped && !_use_matmul)
{
// Reshape the weights
_reshape_weights = std::make_unique<ClTranspose>();
@@ -302,39 +379,47 @@ void ClFullyConnected::configure(const CLCompileContext &compile_context, ITenso
// Update TensorInfo of final weights used (Need to be done in the end due to padding expansion)
_weights_to_use = *weights_used;
- // Set auxiliary memory requirements
- auto gemm_mem_req = (_is_quantized) ? _mm_gemmlowp->workspace() : _mm_gemm->workspace();
- for(unsigned int i = 0; i < gemm_mem_req.size(); ++i)
+ if(_use_matmul)
{
- _aux_mem[i] = gemm_mem_req[i];
- }
- if(_aux_mem[1].size > 0 || _aux_mem[2].size > 0) // Persistent weights memory on GEMMs
- {
- // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
- // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time
- _aux_mem[TransposedWeights] = MemoryInfo(
- offset_int_vec(TransposedWeights),
- _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
- _reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(
- offset_int_vec(ConvertedWeights),
- _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
- _converted_weights.total_size());
+ // Note: MatMul does not transpose the weights and does not need auxiliary memory for them, so only converted weights are added to aux_mem
+ _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Temporary, _converted_weights.total_size());
}
else
{
- // Release permuted weights at the of prepare as they are further transposed by the assembly dispatch
- const auto transposed_wei_lft = (_weights_to_use_idx == offset_int_vec(TransposedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
- const auto converted_wei_lft = (_weights_to_use_idx == offset_int_vec(ConvertedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
-
- _aux_mem[TransposedWeights] = MemoryInfo(
- offset_int_vec(TransposedWeights),
- _dynamic_weights ? MemoryLifetime::Temporary : transposed_wei_lft,
- _reshaped_weights.total_size());
- _aux_mem[ConvertedWeights] = MemoryInfo(
- offset_int_vec(ConvertedWeights),
- _dynamic_weights ? MemoryLifetime::Temporary : converted_wei_lft,
- _converted_weights.total_size());
+ // Set auxiliary memory requirements for gemm operators
+ auto gemm_mem_req = (_is_quantized) ? _mm_gemmlowp->workspace() : _mm_gemm->workspace();
+ for(unsigned int i = 0; i < gemm_mem_req.size(); ++i)
+ {
+ _aux_mem[i] = gemm_mem_req[i];
+ }
+ if(_aux_mem[1].size > 0 || _aux_mem[2].size > 0) // Persistent weights memory on GEMMs
+ {
+ // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+ // Keep all the auxiliary tensors in case of dynamic weights as they are recalculated every time
+ _aux_mem[TransposedWeights] = MemoryInfo(
+ offset_int_vec(TransposedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+ _reshaped_weights.total_size());
+ _aux_mem[ConvertedWeights] = MemoryInfo(
+ offset_int_vec(ConvertedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : MemoryLifetime::Prepare,
+ _converted_weights.total_size());
+ }
+ else
+ {
+ // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+ const auto transposed_wei_lft = (_weights_to_use_idx == offset_int_vec(TransposedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
+ const auto converted_wei_lft = (_weights_to_use_idx == offset_int_vec(ConvertedWeights)) ? MemoryLifetime::Persistent : MemoryLifetime::Prepare;
+
+ _aux_mem[TransposedWeights] = MemoryInfo(
+ offset_int_vec(TransposedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : transposed_wei_lft,
+ _reshaped_weights.total_size());
+ _aux_mem[ConvertedWeights] = MemoryInfo(
+ offset_int_vec(ConvertedWeights),
+ _dynamic_weights ? MemoryLifetime::Temporary : converted_wei_lft,
+ _converted_weights.total_size());
+ }
}
_aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size());
}
@@ -349,8 +434,15 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
&& fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);
- bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
- bool is_fc_after_conv = true;
+ const bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
+ bool is_fc_after_conv = true;
+
+ // When using dynamic weights - use matmul kernels.
+ // Note: MatMul does not support broadcasting or biases, so fall back to GEMM for batched cases or when biases != nullptr.
+ // Note: Pre-reshaped RHS is a deprecated use case and is therefore not supported with matmul.
+ const bool dynamic_weights = !weights->are_values_constant() && !weights_reshaped;
+ const bool is_batched_fc_layer = dst->dimension(1) > 1;
+ const bool use_matmul = dynamic_weights && !is_batched_fc_layer && (biases == nullptr);
const ITensorInfo &flatten_src = TensorInfo(src->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_flatten_shape(src)).set_data_layout(DataLayout::NCHW));
const ITensorInfo &reshaped_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(compute_transposed_shape(*weights)));
@@ -378,8 +470,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
}
}
- // Check if we have a fully connected layer with batches
- const bool is_batched_fc_layer = dst->dimension(1) > 1;
+ // Check if the FC layer comes after a convolution layer (the flatten kernel is run in that case)
if(is_batched_fc_layer)
{
is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3,
@@ -391,7 +482,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
is_fc_after_conv = src->num_dimensions() > 1;
}
- if(!weights_reshaped)
+ if(!weights_reshaped && !use_matmul)
{
// Validate reshape weights kernel
ARM_COMPUTE_RETURN_ON_ERROR(ClTranspose::validate(weights, &reshaped_weights));
@@ -411,7 +502,14 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
if(is_fc_after_conv)
{
// Fully Connected layer after a Convolution Layer without batches
- ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
+ if(use_matmul)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(0) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
+ }
+ else
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON((weights_to_use->dimension(1) != (src->dimension(0) * src->dimension(1) * src->dimension(2))));
+ }
// Validate flatten kernel
ARM_COMPUTE_RETURN_ON_ERROR(ClFlatten::validate(src, &flatten_src));
@@ -420,7 +518,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei
else
{
// Fully Connected layer after a Fully Connected Layer without batches
- ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(src->dimension(0) != weights_to_use->dimension((use_matmul) ? 0 : 1));
}
// Validate matrix multiply kernel
@@ -457,14 +555,30 @@ void ClFullyConnected::run(ITensorPack &tensors)
gemm_pack.add_const_tensor(ACL_SRC_1, weights.get());
}
- // Run matrix multiply
- if(_is_quantized)
+ // Run MatMul Op
+ if(_use_matmul)
{
- _mm_gemmlowp->run(gemm_pack);
+ // Run matmul kernels for matrix multiplication
+ if(_is_quantized)
+ {
+ CLScheduler::get().enqueue_op(*_matmul_lowp_native_kernel, gemm_pack, true);
+ }
+ else
+ {
+ CLScheduler::get().enqueue_op(*_matmul_native_kernel, gemm_pack, true);
+ }
}
else
{
- _mm_gemm->run(gemm_pack);
+ // Run matrix multiply
+ if(_is_quantized)
+ {
+ _mm_gemmlowp->run(gemm_pack);
+ }
+ else
+ {
+ _mm_gemm->run(gemm_pack);
+ }
}
}
@@ -486,7 +600,7 @@ void ClFullyConnected::prepare(ITensorPack &tensors)
const ITensor *cur_weights = weights;
// Reshape of the weights if needed
- if(!_are_weights_reshaped)
+ if(!_are_weights_reshaped && !_use_matmul)
{
// Run reshape weights kernel and mark weights as unused
ITensorPack transpose_pack{ { ACL_SRC, weights }, { ACL_DST, reshaped_weights.get() } };
@@ -509,15 +623,19 @@ void ClFullyConnected::prepare(ITensorPack &tensors)
ITensorPack gemm_pack = tensors;
gemm_pack.add_const_tensor(ACL_SRC_1, cur_weights);
- // Prepare GEMM prepare and release unused weights
- if(!_is_quantized)
+ // Prepare GEMM and release unused weights (if not using matmul)
+ if(!_use_matmul)
{
- _mm_gemm->prepare(gemm_pack);
- }
- else
- {
- _mm_gemmlowp->prepare(gemm_pack);
+ if(!_is_quantized)
+ {
+ _mm_gemm->prepare(gemm_pack);
+ }
+ else
+ {
+ _mm_gemmlowp->prepare(gemm_pack);
+ }
}
+
_is_prepared = true;
}
}
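
As a side note on the LHS reshape introduced above: a small standalone sketch (assuming collapsed_from(2) folds dimensions 2 and upwards into one, as the helper relies on) of the shape mapping performed by get_reshaped_matmul_tensor(), which turns an FC input of shape [X, Y, Z, W, ...] into the [X, 1, Y, Z*W*...] batch layout ClMatMul expects. The main() harness and the example shape are illustrative only.

#include "arm_compute/core/TensorShape.h"
#include <cassert>

using arm_compute::TensorShape;

// Same body as the file-local helper added in ClFullyConnected.cpp
inline TensorShape get_reshaped_matmul_tensor(const TensorShape &src)
{
    return TensorShape(src.x(), 1, src.y(), src.collapsed_from(2).z());
}

int main()
{
    const TensorShape in(64U, 1U, 2U, 3U); // [X=64, Y=1, Z=2, W=3]
    const TensorShape out = get_reshaped_matmul_tensor(in);
    assert(out.x() == 64 && out.y() == 1); // X stays in x, y becomes 1
    assert(out.z() == 1 && out[3] == 6);   // old y moves to z, batch dims collapse into the last dimension
    return 0;
}
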
diff --git a/src/gpu/cl/operators/ClFullyConnected.h b/src/gpu/cl/operators/ClFullyConnected.h
index 11a59b2359..5dc68c1bbe 100644
--- a/src/gpu/cl/operators/ClFullyConnected.h
+++ b/src/gpu/cl/operators/ClFullyConnected.h
@@ -42,7 +42,12 @@ class ClFlatten;
class ClGemm;
class ClGemmLowpMatrixMultiplyCore;
class ClTranspose;
-
+// Kernel Forward Declarations
+namespace kernels
+{
+class ClMatMulNativeKernel;
+class ClMatMulLowpNativeKernel;
+}
/** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
*
* -# @ref opencl::kernels::ClIm2ColKernel (called when the input comes from a convolutional layer)
@@ -119,12 +124,19 @@ private:
std::unique_ptr<ClGemm> _mm_gemm;
std::unique_ptr<ClGemmLowpMatrixMultiplyCore> _mm_gemmlowp;
+ std::unique_ptr<kernels::ClMatMulNativeKernel> _matmul_native_kernel;
+ std::unique_ptr<kernels::ClMatMulLowpNativeKernel> _matmul_lowp_native_kernel;
+
experimental::MemoryRequirements _aux_mem{};
TensorInfo _flattened_src{};
TensorInfo _converted_weights{};
TensorInfo _reshaped_weights{};
+ // Saved tensor shapes for reshaping when using matmul
+ TensorShape _lhs_shape_original{};
+ TensorInfo _lhs_to_use{};
+
TensorInfo _weights_to_use{};
int _weights_to_use_idx{ ACL_SRC_1 };
@@ -134,10 +146,11 @@ private:
bool _is_quantized{ false };
bool _is_prepared{ false };
bool _dynamic_weights{ false };
+ bool _use_matmul{ false };
#ifdef ARM_COMPUTE_ASSERTS_ENABLED
- int _asrt_run_count{};
- int _asrt_prepare_count{};
+ int _asrt_run_count {};
+ int _asrt_prepare_count{};
#endif // ARM_COMPUTE_ASSERTS_ENABLED
};
} // namespace opencl
diff --git a/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h b/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h
index 203f68c253..60e838c5cb 100644
--- a/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h
+++ b/src/runtime/heuristics/matmul_native/IClMatMulNativeKernelConfig.h
@@ -111,6 +111,6 @@ public:
protected:
GPUTarget _target;
};
-} // namespace opencl
+} // namespace cl_matmul
} // namespace arm_compute
#endif /* SRC_RUNTIME_HEURISTICS_MATMUL_NATIVE_ICLMATMULNATIVEKERNELCONFIG */
diff --git a/tests/validation/CL/FullyConnectedLayer.cpp b/tests/validation/CL/FullyConnectedLayer.cpp
index 9213ab541d..474a87dd1c 100644
--- a/tests/validation/CL/FullyConnectedLayer.cpp
+++ b/tests/validation/CL/FullyConnectedLayer.cpp
@@ -131,6 +131,8 @@ template <typename T>
using CLFullyConnectedLayerMixedDataLayoutFixture = FullyConnectedLayerValidationFixture<CLTensor, CLAccessor, CLFullyConnectedLayer, T, true>;
template <typename T>
using CLFullyConnectedLayerDynamicWeightsFixture = FullyConnectedWithDynamicWeightsFixture<CLTensor, CLAccessor, CLFullyConnectedLayer, T>;
+template <typename T>
+using CLFullyConnectedNoBiasFixture = FullyConnectedDynamicNoBiasFixture<CLTensor, CLAccessor, CLFullyConnectedLayer, T>;
TEST_SUITE(Float)
TEST_SUITE(FP16)
@@ -151,9 +153,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture<half>, framework::
validate(CLAccessor(_target), _reference, tolerance_f16, tolerance_num);
}
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<half>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
- framework::dataset::make("WeightsReshaped", { false, true })))
+ framework::dataset::make("DataType", DataType::F16)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
+ framework::dataset::make("WeightsReshaped", { false, true })))
{
}
TEST_SUITE_END()
@@ -179,9 +181,15 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerMixedDataLayoutF
validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0, abs_tolerance_f32);
}
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
- framework::dataset::make("WeightsReshaped", { false, true })))
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
+ framework::dataset::make("WeightsReshaped", { false, true })))
+{
+}
+FIXTURE_DATA_TEST_CASE(RunDynamicNoBias, CLFullyConnectedNoBiasFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })),
+ framework::dataset::make("WeightsReshaped", { false })))
{
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeFullyConnectedLayerDataset(), FullyConnectedParameters),
@@ -230,9 +238,9 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLFullyConnectedLayerQuantizedFixture<uint8_t>,
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
- framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ })))
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
+ framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ })))
{
}
TEST_SUITE_END() /* QASYMM8 */
@@ -259,9 +267,15 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, CLFullyConnectedLayerQuantizedMixedDa
validate(CLAccessor(_target), _reference, tolerance_qasymm8);
}
FIXTURE_DATA_TEST_CASE(RunDynamicWeights, CLFullyConnectedLayerDynamicWeightsFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
- framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ })))
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
+ framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ })))
+{
+}
+FIXTURE_DATA_TEST_CASE(RunDynamicNoBias, CLFullyConnectedNoBiasFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallFullyConnectedLayerDataset(),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", ActivationLayerInfo())),
+ framework::dataset::make("WeightsReshaped", { false /* COMPMID-6000: Support FullyConnected with quantized dynamic weights already reshaped */ })))
{
}
TEST_SUITE_END() // QASYMM8_SIGNED
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 75bef144ad..e13c01d1e2 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -335,9 +335,9 @@ private:
void validate_with_tolerance(TensorType &target, SimpleTensor<half_float::half> &ref)
{
- constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
+ constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f));
- constexpr float tolerance_num_f16 = 0.07f;
+ constexpr float tolerance_num_f16 = 0.07f;
validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
}
@@ -360,36 +360,36 @@ public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped)
+ DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped, bool remove_bias = false)
{
_data_type = data_type;
- const bool is_quantized = is_data_type_quantized(data_type);
-
+ const bool is_quantized = is_data_type_quantized(data_type);
const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
- // Setup tensor meta-data
+ // Configure TensorInfo Objects
const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
- _src.allocator()->init(src_info);
+ const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+ TensorInfo bias_info(bias_shape, 1, bias_data_type);
+ TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
- TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
if(!constant_weights && weights_reshaped)
{
const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
wei_info.set_tensor_shape(tr_weights_shape);
}
wei_info.set_are_values_constant(constant_weights);
- _weights.allocator()->init(wei_info);
-
- TensorInfo bias_info(bias_shape, 1, bias_data_type);
bias_info.set_are_values_constant(constant_bias);
- _bias.allocator()->init(bias_info);
- const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+ // Initialise Tensors
+ _src.allocator()->init(src_info);
+ _weights.allocator()->init(wei_info);
+ if(!remove_bias)
+ _bias.allocator()->init(bias_info);
_dst.allocator()->init(dst_info);
// Configure FC layer and mark the weights as non constant
@@ -401,12 +401,13 @@ public:
fc_info.transpose_weights = !weights_reshaped;
}
FunctionType fc;
- fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
+ fc.configure(&_src, &_weights, (remove_bias) ? nullptr : &_bias, &_dst, fc_info);
// Allocate all the tensors
_src.allocator()->allocate();
_weights.allocator()->allocate();
- _bias.allocator()->allocate();
+ if(!remove_bias)
+ _bias.allocator()->allocate();
_dst.allocator()->allocate();
// Run multiple iterations with different inputs
@@ -424,11 +425,20 @@ public:
fill(AccessorType(_weights), 1);
fill(weights, 1);
}
- if(constant_bias)
+ if(constant_bias && !remove_bias)
{
fill(AccessorType(_bias), 2);
fill(bias, 2);
}
+ // When bias is removed, fill the reference bias tensor with zeros
+ if(remove_bias && is_quantized)
+ {
+ library->fill_tensor_value(bias, 0);
+ }
+ else if(remove_bias)
+ {
+ library->fill_tensor_value(bias, (float)0.0);
+ }
for(int i = 0; i < num_iterations; ++i)
{
@@ -446,7 +456,7 @@ public:
fill(AccessorType(_weights), randomizer_offset + 1);
}
}
- if(!constant_bias)
+ if(!constant_bias && !remove_bias)
{
fill(AccessorType(_bias), randomizer_offset + 2);
}
@@ -462,7 +472,7 @@ public:
{
fill(weights, randomizer_offset + 1);
}
- if(!constant_bias)
+ if(!constant_bias && !remove_bias)
{
fill(bias, randomizer_offset + 2);
}
@@ -491,7 +501,20 @@ public:
DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, false, true, weights_reshaped);
+ dst_shape, data_type, activation_info, false, true, weights_reshaped, false);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedDynamicNoBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+ DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
+ {
+ FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+ dst_shape, data_type, activation_info, false, true, weights_reshaped, true);
}
};
@@ -504,7 +527,7 @@ public:
DataType data_type, ActivationLayerInfo activation_info)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
+ dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */, false /* remove_bias */);
}
};
} // namespace validation