aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/test_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/test/test_graph_optimiser.py')
-rw-r--r--ethosu/vela/test/test_graph_optimiser.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index b37bac80..e0eedd66 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -139,8 +139,7 @@ def create_pad_and_conv2d(
conv_out_tens = Tensor(in_shape, in_dtype, "output")
conv_out_tens.quantization = qp.clone()
weight_tens = Tensor([kernel_size, kernel_size, in_shape[-1], out_shape[-1]], in_dtype, "weights")
- weight_tens.values = np.zeros(weight_tens.shape)
- weight_tens.quant_values = np.zeros(weight_tens.shape, np.int8)
+ weight_tens.values = np.zeros(weight_tens.shape, in_dtype.as_numpy_type())
weight_tens.quantization = qp.clone()
bias_tens = Tensor(out_shape, pad_dtype, "biases")
attrs = {"padding": pad_setting, "stride_w": 2, "stride_h": 2, "dilation_w_factor": 1, "dilation_h_factor": 1}
@@ -349,8 +348,7 @@ def test_remove_reshape():
conv_ofm = Tensor([1, 8, 8, 16], DataType.uint8, "output")
conv_ofm.quantization = quant.clone()
weight_tens = Tensor([1, 1, 16, 16], DataType.uint8, "weights")
- weight_tens.values = np.zeros(weight_tens.shape)
- weight_tens.quant_values = np.zeros(weight_tens.shape, np.uint8)
+ weight_tens.values = np.zeros(weight_tens.shape, np.uint8)
weight_tens.quantization = quant.clone()
bias_tens = Tensor([16], DataType.int32, "biases")