aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-08-11 19:58:50 +0000
committerTai Ly <tai.ly@arm.com>2023-08-14 21:11:00 +0000
commita641dd5598440e806fd1a09196a636de12ee2ce4 (patch)
treed96018d12e4676efd8cde311e517845fedc6e415 /reference_model/src/ops/tensor_ops.cc
parent898d3a208e274d9d178ae3865ad6979e579d4d7f (diff)
downloadreference_model-a641dd5598440e806fd1a09196a636de12ee2ce4.tar.gz
Add support for bias broadcasting
add support for bias broadcasting for operators: - conv2d - conv3d - depthwise_conv2d - transpose_conv2d - fully_connected could not add framework test for this because tf.nn.bias_add requires bias size to match channel dimension. manually tested reference model evaluation on tosa mlir with bias size of 1 Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I70d29d231a63fc03b10e3006cbd6b16b53cca1f2
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc32
1 files changed, 20 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index acdeeeb..d9608b7 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -715,7 +715,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
+ b_out_channels, out_channels);
int pad_top = this->attribute->pad()[0];
int pad_bottom = this->attribute->pad()[1];
@@ -767,7 +768,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eigen::array<Eigen::Index, 2> bias_bcast_dims;
bias_bcast_dims[0] = out_batch * out_height * out_width;
- bias_bcast_dims[1] = 1;
+ bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eigen::array<Eigen::Index, 4> col2im_output_dims;
col2im_output_dims[0] = out_batch;
@@ -924,7 +925,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
+ b_out_channels, out_channels);
int pad_d0 = this->attribute->pad()[0];
int pad_d1 = this->attribute->pad()[1];
@@ -1004,7 +1006,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
bcast[1] = out_depth;
bcast[2] = out_height;
bcast[3] = out_width;
- bcast[4] = 1;
+ bcast[4] = (b_out_channels == 1) ? out_channels : 1;
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
// 2. direct convolution
@@ -1137,8 +1139,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
in_channels * f_multiplier, out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
- out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
+ "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
int pad_top = this->attribute->pad()[0];
int pad_bottom = this->attribute->pad()[1];
@@ -1212,7 +1214,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
bcast[0] = out_batch;
bcast[1] = out_height;
bcast[2] = out_width;
- bcast[3] = 1;
+ bcast[3] = (b_out_channels == 1) ? out_channels : 1;
// initialize with bias
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
@@ -1321,13 +1323,19 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
+ int b_out_channels = this->bias->getShape()[0];
+ int out_channels = this->output->getShape()[1];
+
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
Eigen::array<Eigen::Index, 2> bias_reshape;
bias_reshape[0] = 1;
- bias_reshape[1] = this->bias->getShape()[0];
+ bias_reshape[1] = b_out_channels;
Eigen::array<Eigen::Index, 2> bias_bcast;
bias_bcast[0] = this->input->getShape()[0];
- bias_bcast[1] = 1;
+ bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
@@ -2028,8 +2036,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
f_out_channels, out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
- out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
+ "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
// Check Tosa Level
auto tosa_level = g_func_config.tosa_level;
@@ -2076,7 +2084,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
bcast[0] = out_batch;
bcast[1] = out_height;
bcast[2] = out_width;
- bcast[3] = 1;
+ bcast[3] = (b_out_channels == 1) ? out_channels : 1;
// initialize with bias
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);