diff options
Diffstat (limited to 'src/gpu/cl/operators')
-rw-r--r-- | src/gpu/cl/operators/ClFullyConnected.cpp | 18 | ||||
-rw-r--r-- | src/gpu/cl/operators/ClGemm.cpp | 4 |
2 files changed, 5 insertions, 17 deletions
diff --git a/src/gpu/cl/operators/ClFullyConnected.cpp b/src/gpu/cl/operators/ClFullyConnected.cpp index bd2fddad0b..8b7e336c9f 100644 --- a/src/gpu/cl/operators/ClFullyConnected.cpp +++ b/src/gpu/cl/operators/ClFullyConnected.cpp @@ -169,7 +169,8 @@ void ClFullyConnected::configure_mm(const CLCompileContext &compile_context, ITe fc_info.fp_mixed_precision, // fp_mixed_precision false, // fast_math true, // broadcast_bias - fc_info.activation_info); // activation_info + fc_info.activation_info, // activation_info + fc_info.constant_weights); // constant_weights if(_is_quantized) { @@ -332,7 +333,7 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2); ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(src->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU); - ARM_COMPUTE_RETURN_ERROR_ON(!weights->are_values_constant() && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); + ARM_COMPUTE_RETURN_ERROR_ON(!fc_info.constant_weights && (!fc_info.are_weights_reshaped || fc_info.transpose_weights)); bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true; bool is_fc_after_conv = true; @@ -350,19 +351,6 @@ Status ClFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *wei const ITensorInfo *src_to_use = src; const ITensorInfo *weights_to_use = weights; - if(biases != nullptr) - { - ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); - if(is_data_type_quantized(src->data_type())) - { - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32); - } - else - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); - } - } - // Check if we have a fully connected layer with batches const bool is_batched_fc_layer = dst->dimension(1) > 1; if(is_batched_fc_layer) diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp index 292f531dc4..625c057cf4 100644 --- a/src/gpu/cl/operators/ClGemm.cpp +++ b/src/gpu/cl/operators/ClGemm.cpp @@ -574,7 +574,7 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, // Select GEMMType _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ CLScheduler::get().target(), a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run, - b->are_values_constant()); + gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); @@ -623,7 +623,7 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso { CLScheduler::get().target(), a->data_type(), m, n, k, batch_size, }, - gemm_info.reshape_b_only_on_first_run(), b->are_values_constant()); + gemm_info.reshape_b_only_on_first_run(), gemm_info.constant_weights()); const bool fuse_add_c = (!(helpers::float_ops::is_zero(beta)) && c != nullptr); |