aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
authorEric Kunze <eric.kunze@arm.com>2022-07-01 16:56:09 -0700
committerEric Kunze <eric.kunze@arm.com>2022-08-26 00:01:49 +0000
commitc1a978391b16dbbe634bc3338562066a75a6c678 (patch)
treef871bb95067bfc9c8df3278a3931eca2393c2c9f /reference_model/src/ops/tensor_ops.cc
parent4fb70ed713d49515040b041eff3639c36be17ac8 (diff)
downloadreference_model-c1a978391b16dbbe634bc3338562066a75a6c678.tar.gz
Align padding for transpose_conv2d to match spec
Increasing out pad values now leads to increasing pad. Reference model changes, and test generator changes to match specification definition Change-Id: I4f3ebfbca5048354fb15bedc7ab640ff28ed853a Signed-off-by: Eric Kunze <eric.kunze@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc22
1 files changed, 10 insertions, 12 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 03cb9fb..ef6dfa7 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1498,14 +1498,7 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
return 1;
}
- for (int32_t i : attribute->out_pad())
- {
- if (i < 0)
- {
- printNodeValidationError("OpTransposeConv2d: At least one pad is smaller than zero");
- return 1;
- }
- }
+
for (int32_t i : attribute->stride())
{
@@ -1540,8 +1533,13 @@ int OpTransposeConv2d<InDtype, WeightDtype>::checkTensorAttributes()
int32_t out_pad_left = attribute->out_pad()[2];
int32_t out_pad_right = attribute->out_pad()[3];
- int32_t H = (IH - 1) * stride_y - out_pad_top - out_pad_bottom + kernel_h;
- int32_t W = (IW - 1) * stride_x - out_pad_left - out_pad_right + kernel_w;
+ for (size_t i = 0; i < attribute->out_pad().size(); i++)
+ {
+ ERROR_IF(attribute->out_pad()[i] <= -(weight->getShape()[(i / 2) + 1]), "OpTransposeConv2d: At least one out_pad value is larger than kernel size");
+ }
+
+ int32_t H = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + kernel_h;
+ int32_t W = (IW - 1) * stride_x + out_pad_left + out_pad_right + kernel_w;
if ((OH != H) || (OW != W))
{
@@ -1632,8 +1630,8 @@ int OpTransposeConv2d<InDtype, WeightDtype>::eval()
{
for (int iw = 0; iw < in_width; iw++)
{
- out_x_origin = iw * stride_w - out_pad_left;
- out_y_origin = ih * stride_h - out_pad_top;
+ out_x_origin = iw * stride_w + out_pad_left;
+ out_y_origin = ih * stride_h + out_pad_top;
for (int ic = 0; ic < in_channels; ic++)
{
for (int fh = 0; fh < f_height; fh++)