aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/frameworks/tensor_gen.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 3e70c87..90bda34 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -171,8 +171,9 @@ class TGen:
# TODO: Hard-code the test by making the OFM depth 2x the IFM depth.
# Could randomize this in the future.
- out_channels = ifm_shape[3] * 2
- filter_shape = (filter_d, filter_h, filter_w, ifm_shape[3], out_channels)
+ in_channels = ifm_shape[4]
+ out_channels = in_channels * 2
+ filter_shape = (filter_d, filter_h, filter_w, in_channels, out_channels)
return TGen.tgConvCommon(op, ifm_shape, filter_shape, out_channels, dtype, rng)