aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp')
-rw-r--r--src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
index e75b082ca5..a5b9eca56e 100644
--- a/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
+++ b/src/cpu/operators/CpuDepthwiseConv2dAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,6 +40,7 @@ struct CpuDepthwiseConv2dAssemblyDispatch::LocalImpl
{
std::unique_ptr<kernels::CpuDepthwiseConv2dAssemblyWrapperKernel> asm_kernel{ nullptr };
bool is_prepared{ false };
+ bool are_weights_const{ true };
experimental::MemoryRequirements mem_req{};
};
@@ -62,6 +63,7 @@ void CpuDepthwiseConv2dAssemblyDispatch::configure(const ITensorInfo *src,
const CPUInfo &ci = NEScheduler::get().cpu_info();
const unsigned int num_threads = NEScheduler::get().num_threads();
_pImpl->is_prepared = false;
+ _pImpl->are_weights_const = weights->are_values_constant();
// If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
if(!CpuDepthwiseConv2dAssemblyDispatch::validate(src, weights, bias, dst, info))
@@ -107,10 +109,11 @@ void CpuDepthwiseConv2dAssemblyDispatch::run(ITensorPack &tensors)
void CpuDepthwiseConv2dAssemblyDispatch::prepare(ITensorPack &tensors)
{
- if(!_pImpl->is_prepared)
+ const ITensor *weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+
+ if((!_pImpl->are_weights_const && weights != nullptr) || !_pImpl->is_prepared)
{
// Pack weights and bias
- const ITensor *weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
const ITensor *bias = tensors.get_const_tensor(TensorType::ACL_SRC_2);
ITensor *storage = tensors.get_tensor(TensorType::ACL_INT_1);