aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuDirectConv3d.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuDirectConv3d.h')
-rw-r--r--src/cpu/operators/CpuDirectConv3d.h22
1 files changed, 15 insertions, 7 deletions
diff --git a/src/cpu/operators/CpuDirectConv3d.h b/src/cpu/operators/CpuDirectConv3d.h
index ad04dee0fa..f7c3099be0 100644
--- a/src/cpu/operators/CpuDirectConv3d.h
+++ b/src/cpu/operators/CpuDirectConv3d.h
@@ -57,23 +57,31 @@ public:
~CpuDirectConv3d();
/** Set the input, weights, biases and output tensor info.
*
- * @param[in, out] src Input tensor info.
- * @param[in] weights Set of kernels to convolve the input volume.
- * The 2nd dimension must be the same as the input's volume 1st dimension.
- * Data type supported: Same as @p src.
- * @param[in] biases Set of biases. Can be nullptr. Data type supported: Same as @p src.
+ * Valid data layouts:
+ * - NDHWC
+ *
+ * Valid data type configurations:
+ * |src0 |src1 |src2 |dst |
+ * |:--------------|:------------------|:------|:--------------|
+ * |F16 |F16 |F16 |F16 |
+ * |F32 |F32 |F32 |F32 |
+ *
+ * @param[in, out] src0 Input tensor info.
+ * @param[in] src1 Set of kernels to convolve the input volume.
+ * The 2nd dimension must be the same as the src0's volume 1st dimension.
+ * @param[in] src2 Set of biases. Can be nullptr.
* @param[out] dst Output tensor info.
* The 1st dimensions must be equal to the 1st dimension of the @p kernels tensor.
* @param[in] conv_info Contains padding, stride, acitvation information.
*/
- void configure(ITensorInfo *src, ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const Conv3dInfo conv_info);
+ void configure(ITensorInfo *src0, ITensorInfo *src1, const ITensorInfo *src2, ITensorInfo *dst, const Conv3dInfo conv_info);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuDirectConv3d::configure()
*
* @return a status
*/
- static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv3dInfo conv_info);
+ static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, const Conv3dInfo conv_info);
// Inherited methods overridden:
void run(ITensorPack &tensors) override;