diff options
-rw-r--r-- | verif/frameworks/tensor_gen.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py index 3e70c87..90bda34 100644 --- a/verif/frameworks/tensor_gen.py +++ b/verif/frameworks/tensor_gen.py @@ -171,8 +171,9 @@ class TGen: # TODO: Hard-code the test by making the OFM depth 2x the IFM depth. # Could randomize this in the future. - out_channels = ifm_shape[3] * 2 - filter_shape = (filter_d, filter_h, filter_w, ifm_shape[3], out_channels) + in_channels = ifm_shape[4] + out_channels = in_channels * 2 + filter_shape = (filter_d, filter_h, filter_w, in_channels, out_channels) return TGen.tgConvCommon(op, ifm_shape, filter_shape, out_channels, dtype, rng) |