aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators')
-rw-r--r--src/cpu/operators/CpuDirectConv3d.cpp18
-rw-r--r--src/cpu/operators/CpuDirectConv3d.h22
2 files changed, 22 insertions, 18 deletions
diff --git a/src/cpu/operators/CpuDirectConv3d.cpp b/src/cpu/operators/CpuDirectConv3d.cpp
index 3827910d37..aa74e420a6 100644
--- a/src/cpu/operators/CpuDirectConv3d.cpp
+++ b/src/cpu/operators/CpuDirectConv3d.cpp
@@ -40,10 +40,10 @@ CpuDirectConv3d::CpuDirectConv3d(std::shared_ptr<IMemoryManager> memory_manager)
{
}
-void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info)
+void CpuDirectConv3d::configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info);
- ARM_COMPUTE_ERROR_ON(src->data_layout() != DataLayout::NDHWC);
+ ARM_COMPUTE_LOG_PARAMS(src0, src1, src2, dst, conv_info);
+ ARM_COMPUTE_ERROR_ON(src0->data_layout() != DataLayout::NDHWC);
_conv_kernel = std::make_unique<kernels::CpuDirectConv3dKernel>();
@@ -55,7 +55,7 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
_dim_split = Window::DimY;
- _conv_kernel->configure(src, weights, biases, dst, conv_info);
+ _conv_kernel->configure(src0, src1, src2, dst, conv_info);
//Configure Activation Layer
_is_activationlayer_enabled = conv_info.act_info.enabled();
@@ -66,16 +66,12 @@ void CpuDirectConv3d::configure(ITensorInfo *src, ITensorInfo *weights, const IT
}
}
-Status CpuDirectConv3d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info)
+Status CpuDirectConv3d::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
-
- // output might not be initialized since it can be an intermediate tensor of another layer
- DataType data_type = src->data_type();
- TensorInfo accumulator(dst->clone()->set_is_resizable(true).reset_padding().set_data_type(data_type));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
// Validate Convolution kernel
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src, weights, biases, &accumulator, conv_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuDirectConv3dKernel::validate(src0, src1, src2, dst, conv_info));
if(conv_info.act_info.enabled())
{
diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h
index ad04dee0fa..f7c3099be0 100644
--- a/src/cpu/operators/CpuDirectConv3d.h
+++ b/src/cpu/operators/CpuDirectConv3d.h
@@ -57,23 +57,31 @@ public:
~CpuDirectConv3d();
/** Set the input, weights, biases and output tensor info.
*
- * @param[in, out] src Input tensor info.
- * @param[in] weights Set of kernels to convolve the input volume.
- * The 2nd dimension must be the same as the input's volume 1st dimension.
- * Data type supported: Same as @p src.
- * @param[in] biases Set of biases. Can be nullptr. Data type supported: Same as @p src.
+ * Valid data layouts:
+ * - NDHWC
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |dst |
+ * |:--------------|:------------------|:------|:--------------|
+ * |F16 |F16 |F16 |F16 |
+ * |F32 |F32 |F32 |F32 |
+ *
+ * @param[in, out] src0 Input tensor info.
+ * @param[in] src1 Set of kernels to convolve the input volume.
+ * The 2nd dimension must be the same as the src0's volume 1st dimension.
+ * @param[in] src2 Set of biases. Can be nullptr.
* @param[out] dst Output tensor info.
* The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
* @param[in] conv_info Contains padding, stride, acitvation information.
*/
- void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info);
+ void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuDirectConv3d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info);
// Inherited methods overridden:
void run(ITensorPack &tensors) override;