aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 5670d1b..6f9acf4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -314,12 +314,12 @@ class TosaTensorGen:
def tgMatmul(testGen, op, rank):
pl, const = op["operands"]
- assert rank == 2
+ assert rank == 3
assert pl == 2 and const == 0
a_shape = testGen.makeShape(rank)
b_oc = testGen.makeShape(1)[0]
- b_shape = np.asarray([a_shape[1], b_oc])
+ b_shape = np.asarray([a_shape[0], a_shape[2], b_oc])
return [a_shape, b_shape]
@@ -1994,7 +1994,7 @@ class TosaTestGen:
"matmul": {
"op": Op.MATMUL,
"operands": (2, 0),
- "rank": (2, 2),
+ "rank": (3, 3),
"build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
@@ -2630,11 +2630,11 @@ class OutputShaper:
@staticmethod
def matmulOp(ser, a, b):
- # a: M, K
- # b: K, N
- # out: M, N
+ # a: N, H, C
+ # b: N, C, W
+ # out: N, H, W
- output_shape = [a.shape[0], b.shape[1]]
+ output_shape = [a.shape[0], a.shape[1], b.shape[2]]
if a.dtype == DType.INT8:
out_dtype = DType.INT32