aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTatWai Chong <tatwai.chong@arm.com>2022-08-29 14:50:48 -0700
committerTatWai Chong <tatwai.chong@arm.com>2022-08-29 14:51:20 -0700
commit5a76b2a81ee1f62dede1f6549ab3f7924338a9eb (patch)
treef5ac33845ae12ae202f089e155db0550c2f8b927
parentfd62905d807b5976bea28b6d766e614c076faacf (diff)
downloadreference_model-5a76b2a81ee1f62dede1f6549ab3f7924338a9eb.tar.gz
[Fix] Wrong dimension is inferred in tensor shape generation
Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: I0bccfbe971f64986d71cef5a1d68daa7eb1697c4
-rw-r--r--verif/frameworks/tensor_gen.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/verif/frameworks/tensor_gen.py b/verif/frameworks/tensor_gen.py
index 3e70c87..90bda34 100644
--- a/verif/frameworks/tensor_gen.py
+++ b/verif/frameworks/tensor_gen.py
@@ -171,8 +171,9 @@ class TGen:
# TODO: Hard-code the test by making the OFM depth 2x the IFM depth.
# Could randomize this in the future.
- out_channels = ifm_shape[3] * 2
- filter_shape = (filter_d, filter_h, filter_w, ifm_shape[3], out_channels)
+ in_channels = ifm_shape[4]
+ out_channels = in_channels * 2
+ filter_shape = (filter_d, filter_h, filter_w, in_channels, out_channels)
return TGen.tgConvCommon(op, ifm_shape, filter_shape, out_channels, dtype, rng)