aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/datasets/Col2ImLayerDataset.h4
-rw-r--r--tests/validation/reference/Col2Im.cpp2
2 files changed, 3 insertions, 3 deletions
diff --git a/tests/datasets/Col2ImLayerDataset.h b/tests/datasets/Col2ImLayerDataset.h
index 96a3cab134..b39cedbde6 100644
--- a/tests/datasets/Col2ImLayerDataset.h
+++ b/tests/datasets/Col2ImLayerDataset.h
@@ -128,7 +128,7 @@ public:
add_config(TensorShape(8U, 16U, 3U, 1U), 4U, 4U, 3U);
add_config(TensorShape(8U, 16U, 3U, 3U), 4U, 4U, 3U);
add_config(TensorShape(12U, 20U, 4U, 1U), 5U, 4U, 4U);
- add_config(TensorShape(12U, 20U, 4U, 3U), 5U, 4U, 4U);
+ add_config(TensorShape(12U, 20U, 4U, 3U, 2U), 5U, 4U, 4U);
}
};
@@ -142,7 +142,7 @@ public:
add_config(TensorShape(333U, 280U, 1U, 77U), 14U, 20U, 1U);
add_config(TensorShape(333U, 280U, 77U, 1U), 14U, 20U, 1U);
add_config(TensorShape(120U, 300U, 8U, 3U), 20U, 15U, 8U);
- add_config(TensorShape(233U, 300U, 8U, 3U), 20U, 15U, 8U);
+ add_config(TensorShape(233U, 300U, 8U, 3U, 2U), 20U, 15U, 8U);
add_config(TensorShape(333U, 280U, 12U, 5U), 20U, 14U, 12U);
add_config(TensorShape(177U, 300U, 12U, 5U), 15U, 20U, 12U);
add_config(TensorShape(450U, 400U, 16U, 5U), 20U, 20U, 16U);
diff --git a/tests/validation/reference/Col2Im.cpp b/tests/validation/reference/Col2Im.cpp
index 90e488f928..53969d4725 100644
--- a/tests/validation/reference/Col2Im.cpp
+++ b/tests/validation/reference/Col2Im.cpp
@@ -40,7 +40,7 @@ SimpleTensor<T> col2im(const SimpleTensor<T> &src, const TensorShape &dst_shape,
SimpleTensor<T> dst{ dst_shape, src.data_type(), 1 };
// Compute reference
- const size_t batches = dst_shape[3];
+ const size_t batches = dst_shape.total_size() / (dst_shape.x() * dst_shape.y() * dst_shape.z());
const size_t src_width = src.shape().x();
const size_t src_height = src.shape().y();