From a641dd5598440e806fd1a09196a636de12ee2ce4 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 11 Aug 2023 19:58:50 +0000 Subject: Add support for bias broadcasting add support for bias broadcasting for operators: - conv2d - conv3d - depthwise_conv2d - transpose_conv2d - fully_connected could not add framework test for this because tf.nn.bias_add requires bias size to match channel dimension. manually tested reference model evaluation on tosa mlir with bias size of 1 Signed-off-by: Tai Ly Change-Id: I70d29d231a63fc03b10e3006cbd6b16b53cca1f2 --- reference_model/src/ops/tensor_ops.cc | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index acdeeeb..d9608b7 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -715,7 +715,8 @@ int OpConv2d::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d", + b_out_channels, out_channels); int pad_top = this->attribute->pad()[0]; int pad_bottom = this->attribute->pad()[1]; @@ -767,7 +768,7 @@ int OpConv2d::eval() Eigen::array bias_bcast_dims; bias_bcast_dims[0] = out_batch * out_height * out_width; - bias_bcast_dims[1] = 1; + bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1; Eigen::array col2im_output_dims; col2im_output_dims[0] = out_batch; @@ -924,7 +925,8 @@ int OpConv3d::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d", + b_out_channels, out_channels); int pad_d0 = this->attribute->pad()[0]; int pad_d1 = this->attribute->pad()[1]; @@ -1004,7 +1006,7 @@ int OpConv3d::eval() bcast[1] = out_depth; bcast[2] = out_height; bcast[3] = out_width; - bcast[4] = 1; + bcast[4] = (b_out_channels == 1) ? out_channels : 1; this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); // 2. direct convolution @@ -1137,8 +1139,8 @@ int OpDepthwiseConv2d::eval() in_channels); ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, + "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels); int pad_top = this->attribute->pad()[0]; int pad_bottom = this->attribute->pad()[1]; @@ -1212,7 +1214,7 @@ int OpDepthwiseConv2d::eval() bcast[0] = out_batch; bcast[1] = out_height; bcast[2] = out_width; - bcast[3] = 1; + bcast[3] = (b_out_channels == 1) ? out_channels : 1; // initialize with bias this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); @@ -1321,13 +1323,19 @@ int OpFullyConnected::eval() Eigen::array weight_shuffle{ 1, 0 }; + int b_out_channels = this->bias->getShape()[0]; + int out_channels = this->output->getShape()[1]; + + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d", + b_out_channels, out_channels); + Eigen::array bias_reshape; bias_reshape[0] = 1; - bias_reshape[1] = this->bias->getShape()[0]; + bias_reshape[1] = b_out_channels; Eigen::array bias_bcast; bias_bcast[0] = this->input->getShape()[0]; - bias_bcast[1] = 1; + bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1; TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); @@ -2028,8 +2036,8 @@ int OpTransposeConv2d::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, + "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels); // Check Tosa Level auto tosa_level = g_func_config.tosa_level; @@ -2076,7 +2084,7 @@ int OpTransposeConv2d::eval() bcast[0] = out_batch; bcast[1] = out_height; bcast[2] = out_width; - bcast[3] = 1; + bcast[3] = (b_out_channels == 1) ? out_channels : 1; // initialize with bias this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); -- cgit v1.2.1