diff options
author | Tai Ly <tai.ly@arm.com> | 2023-08-11 19:58:50 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-08-14 21:11:00 +0000 |
commit | a641dd5598440e806fd1a09196a636de12ee2ce4 (patch) | |
tree | d96018d12e4676efd8cde311e517845fedc6e415 | |
parent | 898d3a208e274d9d178ae3865ad6979e579d4d7f (diff) | |
download | reference_model-a641dd5598440e806fd1a09196a636de12ee2ce4.tar.gz |
Add support for bias broadcasting
Add support for bias broadcasting for the following operators:
- conv2d
- conv3d
- depthwise_conv2d
- transpose_conv2d
- fully_connected
Could not add a framework test for this because tf.nn.bias_add requires
the bias size to match the channel dimension.
Manually tested reference model evaluation on TOSA MLIR with a bias size
of 1.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I70d29d231a63fc03b10e3006cbd6b16b53cca1f2
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index acdeeeb..d9608b7 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -715,7 +715,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d", + b_out_channels, out_channels); int pad_top = this->attribute->pad()[0]; int pad_bottom = this->attribute->pad()[1]; @@ -767,7 +768,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() Eigen::array<Eigen::Index, 2> bias_bcast_dims; bias_bcast_dims[0] = out_batch * out_height * out_width; - bias_bcast_dims[1] = 1; + bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1; Eigen::array<Eigen::Index, 4> col2im_output_dims; col2im_output_dims[0] = out_batch; @@ -924,7 +925,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d", + b_out_channels, out_channels); int pad_d0 = this->attribute->pad()[0]; int pad_d1 = this->attribute->pad()[1]; @@ -1004,7 +1006,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() bcast[1] = out_depth; bcast[2] = out_height; bcast[3] = out_width; - bcast[4] = 1; + bcast[4] = (b_out_channels == 1) ? out_channels : 1; this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); // 2. 
direct convolution @@ -1137,8 +1139,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() in_channels); ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, + "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels); int pad_top = this->attribute->pad()[0]; int pad_bottom = this->attribute->pad()[1]; @@ -1212,7 +1214,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() bcast[0] = out_batch; bcast[1] = out_height; bcast[2] = out_width; - bcast[3] = 1; + bcast[3] = (b_out_channels == 1) ? out_channels : 1; // initialize with bias this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); @@ -1321,13 +1323,19 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 }; + int b_out_channels = this->bias->getShape()[0]; + int out_channels = this->output->getShape()[1]; + + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d", + b_out_channels, out_channels); + Eigen::array<Eigen::Index, 2> bias_reshape; bias_reshape[0] = 1; - bias_reshape[1] = this->bias->getShape()[0]; + bias_reshape[1] = b_out_channels; Eigen::array<Eigen::Index, 2> bias_bcast; bias_bcast[0] = this->input->getShape()[0]; - bias_bcast[1] = 1; + bias_bcast[1] = (b_out_channels == 1) ? 
out_channels : 1; TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); @@ -2028,8 +2036,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() in_channels); ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", f_out_channels, out_channels); - ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, - out_channels); + ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, + "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels); // Check Tosa Level auto tosa_level = g_func_config.tosa_level; @@ -2076,7 +2084,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() bcast[0] = out_batch; bcast[1] = out_height; bcast[2] = out_width; - bcast[3] = 1; + bcast[3] = (b_out_channels == 1) ? out_channels : 1; // initialize with bias this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast); |