aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuDirectConv3d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuDirectConv3d.cpp')
-rw-r--r--src/cpu/operators/CpuDirectConv3d.cpp18
1 files changed, 7 insertions, 11 deletions
diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp
index 3827910d37..aa74e420a6 100644
--- a/src/cpu/operators/CpuDirectConv3d.cpp
+++ b/src/cpu/operators/CpuDirectConv3d.cpp
@@ -40,10 +40,10 @@ CpuDirectConv3d::CpuDirectConv3d(std::shared_ptr<IMemoryManager> memory_manager)
{
}
-void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info)
+void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info);
- ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
+ ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info);
+ ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
_conv_kernel = std::make_unique<kernels::CpuDirectConv3dKernel>();
@@ -55,7 +55,7 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
_dim_split = Window::DimY;
- _conv_kernel->configure(src, weights, biases, dst, conv_info);
+ _conv_kernel->configure(src0, src1, src2, dst, conv_info);
//Configure Activation Layer
_is_activationlayer_enabled = conv_info.act_info.enabled();
@@ -66,16 +66,12 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
}
}
-Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info)
+Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-
- // output might not be initialized since it can be an intermediate tensor of another layer
- DataType data_type = src->data_type();
- TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
// Validate Convolution kernel
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info));
if(conv_info.act_info.enabled())
{