aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc32
1 files changed, 20 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index acdeeeb..d9608b7 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -715,7 +715,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpConv2d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv2d: bias channel mismatch %d != %d",
+ b_out_channels, out_channels);
int pad_top = this->attribute->pad()[0];
int pad_bottom = this->attribute->pad()[1];
@@ -767,7 +768,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
Eigen::array<Eigen::Index, 2> bias_bcast_dims;
bias_bcast_dims[0] = out_batch * out_height * out_width;
- bias_bcast_dims[1] = 1;
+ bias_bcast_dims[1] = (b_out_channels == 1) ? out_channels : 1;
Eigen::array<Eigen::Index, 4> col2im_output_dims;
col2im_output_dims[0] = out_batch;
@@ -924,7 +925,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpConv3d: tensor output channel mismatch %d != %d", f_out_channels,
out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpConv3d: bias channel mismatch %d != %d", b_out_channels, out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpConv3d: bias channel mismatch %d != %d",
+ b_out_channels, out_channels);
int pad_d0 = this->attribute->pad()[0];
int pad_d1 = this->attribute->pad()[1];
@@ -1004,7 +1006,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
bcast[1] = out_depth;
bcast[2] = out_height;
bcast[3] = out_width;
- bcast[4] = 1;
+ bcast[4] = (b_out_channels == 1) ? out_channels : 1;
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
// 2. direct convolution
@@ -1137,8 +1139,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(in_channels * f_multiplier != out_channels, "OpDepthwiseConv2d: tensor output channel mismatch %d != %d",
in_channels * f_multiplier, out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
- out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
+ "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
int pad_top = this->attribute->pad()[0];
int pad_bottom = this->attribute->pad()[1];
@@ -1212,7 +1214,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
bcast[0] = out_batch;
bcast[1] = out_height;
bcast[2] = out_width;
- bcast[3] = 1;
+ bcast[3] = (b_out_channels == 1) ? out_channels : 1;
// initialize with bias
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);
@@ -1321,13 +1323,19 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
+ int b_out_channels = this->bias->getShape()[0];
+ int out_channels = this->output->getShape()[1];
+
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1, "OpFullyConnected: bias channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
Eigen::array<Eigen::Index, 2> bias_reshape;
bias_reshape[0] = 1;
- bias_reshape[1] = this->bias->getShape()[0];
+ bias_reshape[1] = b_out_channels;
Eigen::array<Eigen::Index, 2> bias_bcast;
bias_bcast[0] = this->input->getShape()[0];
- bias_bcast[1] = 1;
+ bias_bcast[1] = (b_out_channels == 1) ? out_channels : 1;
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
@@ -2028,8 +2036,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
in_channels);
ERROR_IF(f_out_channels != out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
f_out_channels, out_channels);
- ERROR_IF(b_out_channels != out_channels, "OpDepthwiseConv2d: bias channels mismatch %d != %d", b_out_channels,
- out_channels);
+ ERROR_IF(b_out_channels != out_channels && b_out_channels != 1,
+ "OpTransposeConv2d: bias channels mismatch %d != %d", b_out_channels, out_channels);
// Check Tosa Level
auto tosa_level = g_func_config.tosa_level;
@@ -2076,7 +2084,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
bcast[0] = out_batch;
bcast[1] = out_height;
bcast[2] = out_width;
- bcast[3] = 1;
+ bcast[3] = (b_out_channels == 1) ? out_channels : 1;
// initialize with bias
this->output->getTensor() = bias_val.reshape(reshape_dim).broadcast(bcast);