aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Tim Hall <tim.hall@arm.com> 2023-01-13 17:57:25 +0000
committer: tim.hall <tim.hall@arm.com> 2023-01-20 14:07:21 +0000
commit: 3b1578e44b4c6a8c8c9a8e0891d3866a89bd66af (patch)
tree: 491c337bc854d435b80f0a535496084ea9ebc9ac
parent: f34904717f643499f3ea6210322bbe1b635db088 (diff)
download: ethos-u-vela-3b1578e44b4c6a8c8c9a8e0891d3866a89bd66af.tar.gz
MLBEDSW-7151: MLCE: Difference in model output between x86 & aarch64
- The issue is due to undefined behaviour when casting a NumPy float to a NumPy unsigned integer, which occurs in create_const_tensor()
- The fix is to make sure that the values are first cast to a Python float
- In addition, the values datatype argument has been removed from create_const_tensor() to stop the tensor and values datatypes getting out of sync

Change-Id: I134b9be8c941b361929a5ae7db8cb35f2e9728f2
Signed-off-by: Tim Hall <tim.hall@arm.com>
-rw-r--r-- ethosu/vela/data_type.py 4
-rw-r--r-- ethosu/vela/graph_optimiser_util.py 2
-rw-r--r-- ethosu/vela/lut.py 5
-rw-r--r-- ethosu/vela/softmax.py 50
-rw-r--r-- ethosu/vela/tensor.py 26
-rw-r--r-- ethosu/vela/test/test_graph_optimiser.py 21
-rw-r--r-- ethosu/vela/test/test_lut.py 16
-rw-r--r-- ethosu/vela/test/test_tflite_model_semantic.py 13
-rw-r--r-- ethosu/vela/test/test_tflite_supported_operators.py 33
-rw-r--r-- ethosu/vela/test/testutil.py 26
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 35
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py 20
12 files changed, 99 insertions, 152 deletions
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index 829cef38..5d0320b2 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -110,7 +110,7 @@ class DataType:
BaseType.Complex: "c",
}
assert self.type in numpy_dtype_code, f"Failed to interpret {self} as a numpy dtype"
- return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes()))
+ return np.dtype(numpy_dtype_code[self.type] + str(self.size_in_bytes())).type
stem_name = {
BaseType.UnsignedInt: ("uint%s", True),
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index d90c06bd..2822feb8 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -429,7 +429,7 @@ def convert_to_lut(op, lut_values, lut_name):
quantization = QuantizationParameters(0.0, 255.0)
quantization.scale_f32 = ifm.quantization.scale_f32
quantization.zero_point = 0
- tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], np.uint8, quantization=quantization)
+ tens = create_const_tensor(op.inputs[0].name + "_scalar0", [], ifm.dtype, [0], quantization=quantization)
op.add_input_tensor(tens)
op.ifm_shapes.append(Shape4D(tens.shape)) # TODO no shape?
diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py
index fdf9d0ff..d0ac9706 100644
--- a/ethosu/vela/lut.py
+++ b/ethosu/vela/lut.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -88,8 +88,7 @@ def create_lut_tensor(name, values, dtype):
# address in constant memory, and unnecessary DMA operations can be avoided.
sz = len(values)
assert sz in (256, 512)
- ntype = np.uint8 if dtype.size_in_bytes() == 1 else np.uint32
- tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, ntype, TensorPurpose.LUT)
+ tens = create_const_tensor(name, [1, 1, 1, sz], dtype, values, TensorPurpose.LUT)
tens.equivalence_id = create_equivalence_id(tuple(values))
return tens
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index a92d0bb2..575e1e66 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
@@ -270,9 +270,7 @@ class SoftMax:
ifm2_shape=ifm_max_shape,
)
sub_op.set_activation_lut(
- create_const_tensor(
- f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
- )
+ create_const_tensor(f"{sub_op.name}_exp_lut", [1, 1, 1, 256], DataType.int32, exp_lut, TensorPurpose.LUT)
)
ifm_exp = add_op_get_ofm(sub_op)
# Note: activation.min/max are non-quantized values
@@ -281,9 +279,7 @@ class SoftMax:
# PASS 2 - SHR
name = f"{self.op.name}_shr{pass_number}"
- shift = create_const_tensor(
- f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], np.int32, quantization=no_scale_quant
- )
+ shift = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [12], quantization=no_scale_quant)
shr_op = create_shr(name, ifm_exp, shift, no_scale_quant, activation)
shr_op.rounding_mode = NpuRoundingMode.NATURAL
rescaled_exp = add_op_get_ofm(shr_op)
@@ -304,7 +300,6 @@ class SoftMax:
[1, 1, 1, 1],
DataType.int32,
[12 + 31 - 8],
- np.int32,
quantization=no_scale_quant,
)
right_shift = add_op_get_ofm(
@@ -318,7 +313,7 @@ class SoftMax:
)
# PASS 6 - Sub
- one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
+ one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], quantization=no_scale_quant)
headroom = add_op_get_ofm(
create_sub(f"{self.op.name}_sub{pass_number}", headroom_plus_one, one, no_scale_quant, activation)
)
@@ -330,7 +325,7 @@ class SoftMax:
# PASS 8 - Sub
shifted_one = create_const_tensor(
- "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], np.int32, quantization=no_scale_quant
+ "shifted_one_const", [1, 1, 1, 1], DataType.int32, [1 << 30], quantization=no_scale_quant
)
shifted_sum_minus_one = add_op_get_ofm(
create_sub(f"{self.op.name}_sub{pass_number}", shifted_sum, shifted_one, no_scale_quant, activation)
@@ -349,7 +344,7 @@ class SoftMax:
# PASS 10 - Add
f0_one_const = create_const_tensor(
- "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
+ "F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], quantization=no_scale_quant
)
add_op = create_add(
f"{self.op.name}_add{pass_number}",
@@ -363,7 +358,7 @@ class SoftMax:
# PASS 11 - Multiply
neg_32_over_17 = create_const_tensor(
- "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], np.int32, quantization=one_scale_quant
+ "neg_32_over_17_const", [1, 1, 1, 1], DataType.int32, [-1010580540], quantization=one_scale_quant
)
rescaled = add_op_get_ofm(
create_mul(
@@ -377,7 +372,7 @@ class SoftMax:
# PASS 12 - Add
const_48_over_17 = create_const_tensor(
- "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
+ "48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], quantization=no_scale_quant
)
rescale_w_offset = add_op_get_ofm(
create_add(
@@ -392,11 +387,9 @@ class SoftMax:
# PASS 13 - 27
nr_x = rescale_w_offset
F2_one = create_const_tensor(
- "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
- )
- four = create_const_tensor(
- "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
+ "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], quantization=no_scale_quant
)
+ four = create_const_tensor("four_const", [1, 1, 1, 1], DataType.int32, [4], quantization=no_scale_quant)
for _ in range(3):
# PASS 13, 18, 23 - MUL
half_denominator_times_x = add_op_get_ofm(
@@ -438,7 +431,7 @@ class SoftMax:
)
# PASS 28 - Multiply
- two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
+ two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], quantization=no_scale_quant)
scale_factor = add_op_get_ofm(
create_mul(f"{self.op.name}_mul{pass_number}", nr_x, two, one_scale_quant, activation)
)
@@ -502,20 +495,18 @@ class SoftMax:
mul2_quant = ofm.quantization.clone()
mul2_quant.scale_f32 = mul2_out_range
scale = create_const_tensor(
- f"{name}_scale_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], np.int32, quantization=scale_quant
+ f"{name}_scale_const", [1, 1, 1, 1], DataType.int32, [mul2_scale], quantization=scale_quant
)
mul2_ofm = add_op_get_ofm(create_mul(name, sub1_ofm, scale, mul2_quant))
# PASS 3 - Add+LUT(exp)
name = f"{self.op.name}_add{pass_number}"
const_add = create_const_tensor(
- f"{name}_const", [1, 1, 1, 1], DataType.int32, [32767], np.int32, quantization=no_scale_quant
+ f"{name}_const", [1, 1, 1, 1], DataType.int32, [32767], quantization=no_scale_quant
)
add_op = create_add(name, mul2_ofm, const_add, mul2_ofm.quantization.clone(), dtype=DataType.int16)
add_op.set_activation_lut(
- create_const_tensor(
- f"{name}_exp_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, np.int32, TensorPurpose.LUT
- )
+ create_const_tensor(f"{name}_exp_lut", [1, 1, 1, 512], DataType.int32, self.EXP_LUT, TensorPurpose.LUT)
)
ifm_exp = add_op_get_ofm(add_op)
@@ -529,13 +520,11 @@ class SoftMax:
# PASS 6 - Sub
name = f"{self.op.name}_sub{pass_number}"
- const_31 = create_const_tensor(
- f"{name}_const", [1, 1, 1, 1], DataType.int32, [31], np.int32, quantization=no_scale_quant
- )
+ const_31 = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [31], quantization=no_scale_quant)
reciprocal_right_shift = add_op_get_ofm(create_sub(name, const_31, headroom_plus_one, no_scale_quant))
# PASS 7 - SHL
- one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant)
+ one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], quantization=no_scale_quant)
constant_one = add_op_get_ofm(
create_shl(f"{self.op.name}_shl{pass_number}", one, reciprocal_right_shift, no_scale_quant)
)
@@ -552,15 +541,13 @@ class SoftMax:
# PASS 10 - SHR
name = f"{self.op.name}_shr{pass_number}"
- shift = create_const_tensor(
- f"{name}_const", [1, 1, 1, 1], DataType.int32, [15], np.int32, quantization=no_scale_quant
- )
+ shift = create_const_tensor(f"{name}_const", [1, 1, 1, 1], DataType.int32, [15], quantization=no_scale_quant)
shifted_sum_minus_one_16 = add_op_get_ofm(create_shr(name, shifted_sum_minus_one, shift, no_scale_quant))
# PASS 11 - Sub+LUT(one over one plus x)
name = f"{self.op.name}_sub{pass_number}"
sub11_const = create_const_tensor(
- f"{name}_const", [1, 1, 1, 1], DataType.int32, [32768], np.int32, quantization=no_scale_quant
+ f"{name}_const", [1, 1, 1, 1], DataType.int32, [32768], quantization=no_scale_quant
)
sub11_op = create_sub(name, shifted_sum_minus_one_16, sub11_const, no_scale_quant, dtype=DataType.int16)
sub11_op.set_activation_lut(
@@ -569,7 +556,6 @@ class SoftMax:
[1, 1, 1, 512],
DataType.int32,
self.ONE_OVER_ONE_PLUS_X_LUT,
- np.uint32,
TensorPurpose.LUT,
)
)
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 899b1bed..6a95bad4 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -300,17 +300,31 @@ class QuantizationParameters:
def create_const_tensor(
name: str,
shape: Shape,
- dtype: DataType,
- values: np.ndarray,
- value_dtype: np.dtype = None,
+ dtype: DataType, # datatype of the tensor
+ values: Optional[Union[np.ndarray, list]], # list-like data of some type, or scalar (skip mypy), or None
purpose: TensorPurpose = TensorPurpose.Unknown,
- quantization: QuantizationParameters = None,
+ quantization: Optional[QuantizationParameters] = None,
):
+ assert isinstance(dtype, DataType)
+
# Tensor
const_tensor = Tensor(shape, dtype, name + "_0")
const_tensor.purpose = purpose
const_tensor.quantization = quantization
- const_tensor.values = np.array(values, dtype=value_dtype)
+
+ # if the tensor datatype does not match that of the values then np.array() will perform a cast operation. this can
+ # result in undefined behaviour if casting from a numpy float to a numpy unsigned integer. therefore, we need to
+ # avoid this undefined behaviour by converting the numpy floats to python floats as these give the desired behaviour
+ # when casting to unsigned integers
+ if (
+ values is not None
+ and shape != [] # values are not a scalar
+ and isinstance(values[0], np.floating)
+ and dtype.type == BaseType.Unsigned
+ ):
+ values = [float(v) for v in values]
+
+ const_tensor.values = np.array(values, dtype=dtype.as_numpy_type())
# Operator
const_op = Operation(Op.Const, name)
const_op.set_output_tensor(const_tensor)
diff --git a/ethosu/vela/test/test_graph_optimiser.py b/ethosu/vela/test/test_graph_optimiser.py
index 152669f7..54dd70f6 100644
--- a/ethosu/vela/test/test_graph_optimiser.py
+++ b/ethosu/vela/test/test_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -40,9 +40,9 @@ from ethosu.vela.tflite_graph_optimiser import rewrite_fully_connected_input
def test_convert_batched_fc():
"""Tests shape conversion of batched fully connected"""
ifm_shape = [4, 8]
- ifm = create_const_tensor("test_in", ifm_shape, np.uint8, np.zeros(ifm_shape))
+ ifm = create_const_tensor("test_in", ifm_shape, DataType.uint8, np.zeros(ifm_shape))
w_shape = [8, 4]
- weights = create_const_tensor("weight_in", w_shape, np.uint8, np.zeros(w_shape))
+ weights = create_const_tensor("weight_in", w_shape, DataType.uint8, np.zeros(w_shape))
ofm = Tensor(ifm.shape, np.uint8, "test_out")
op = testutil.create_op(Op.FullyConnected, [ifm, weights], ofm)
@@ -132,7 +132,8 @@ def create_pad_and_conv2d(
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, in_dtype, "in")
in0.quantization = qp
- pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ shape = [] if padding == [] else list(np.shape(padding))
+ pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
out = Tensor(out_shape, out_dtype, "out")
out.quantization = qp.clone()
op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -543,9 +544,7 @@ def test_quant_static_optimisations():
Tests if the quant value at vela compile time is calculated correctly
"""
- quant_ifm = create_const_tensor(
- "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
- )
+ quant_ifm = create_const_tensor("const_quant_ifm", values=np.array(127), shape=[], dtype=DataType.int8)
quant_ifm.quantization = testutil.default_quant_params()
quant_ifm.quantization.scale_f32 = 0.15748031
quant_ifm.quantization.quant_min = -128
@@ -568,9 +567,7 @@ def test_quant_static_optimisations():
assert op.ofm.values == 127
- quant_ifm = create_const_tensor(
- "const_quant_ifm", values=np.array(127), value_dtype=np.int8, shape=[], dtype=DataType.int8
- )
+ quant_ifm = create_const_tensor("const_quant_ifm", values=np.array(127), shape=[], dtype=DataType.int8)
quant_ifm.quantization = testutil.default_quant_params()
quant_ifm.quantization.scale_f32 = 0.15748031
quant_ifm.quantization.quant_min = -128
@@ -600,9 +597,7 @@ def test_optimise_quantize_multiple_values():
when passing multiple values to quantize node
"""
- quant_ifm = create_const_tensor(
- "const_quant_ifm", values=np.array([127, 127]), value_dtype=np.int8, shape=[], dtype=DataType.int8
- )
+ quant_ifm = create_const_tensor("const_quant_ifm", values=np.array([127, 127]), shape=[], dtype=DataType.int8)
quant_ifm.quantization = testutil.default_quant_params()
quant_ifm.quantization.scale_f32 = 0.15748031
quant_ifm.quantization.quant_min = -128
diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py
index 90732707..712be7a2 100644
--- a/ethosu/vela/test/test_lut.py
+++ b/ethosu/vela/test/test_lut.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -18,8 +18,6 @@
# Unit tests for LUT support
import random
-import numpy as np
-
from ethosu.vela import lut
from ethosu.vela import mark_tensors
from ethosu.vela import pass_packing
@@ -37,9 +35,7 @@ from ethosu.vela.test import testutil
def set_256_lut(op, key, arch):
random.seed(key)
values = random.choices(range(256), k=256)
- lut_tensor = create_const_tensor(
- op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT
- )
+ lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, TensorPurpose.LUT)
scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
op.set_activation_lut(scratch_lut_tensor)
@@ -47,9 +43,7 @@ def set_256_lut(op, key, arch):
def set_1K_lut(op, key, arch):
random.seed(key)
values = random.choices(range(256), k=256)
- lut_tensor = create_const_tensor(
- op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT
- )
+ lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, TensorPurpose.LUT)
scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
op.set_activation_lut(scratch_lut_tensor)
@@ -57,9 +51,7 @@ def set_1K_lut(op, key, arch):
def set_2K_lut(op, key, arch):
random.seed(key)
values = random.choices(range(512), k=512)
- lut_tensor = create_const_tensor(
- op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT
- )
+ lut_tensor = create_const_tensor(op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, TensorPurpose.LUT)
scratch_lut_tensor = lut_tensor.clone_into_fast_storage(arch)
op.set_activation_lut(scratch_lut_tensor)
diff --git a/ethosu/vela/test/test_tflite_model_semantic.py b/ethosu/vela/test/test_tflite_model_semantic.py
index c242063d..2e0936d0 100644
--- a/ethosu/vela/test/test_tflite_model_semantic.py
+++ b/ethosu/vela/test/test_tflite_model_semantic.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -195,11 +195,11 @@ def test_constraint_splitv_inferred():
# SplitV requires a maximum of one inferred shape (-1)
qp = testutil.default_quant_params()
op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], np.int16, quantization=qp)
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, -1, 2, -1]]]], quantization=qp)
op.add_input_tensor(sizes)
assert not semantic_checker.is_operator_semantic_valid(op)
op = testutil.create_op_with_quant_tensors(Op.SplitV, [1, 1, 1, 8], [1, 1, 1, 8])
- sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], np.int16, quantization=qp)
+ sizes = create_const_tensor("sizes", [1, 1, 1, 4], DataType.int16, [[[[0, 1, 2, -1]]]], quantization=qp)
op.add_input_tensor(sizes)
assert semantic_checker.is_operator_semantic_valid(op)
@@ -278,7 +278,8 @@ def create_pad_op(
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, in_dtype, "in")
in0.quantization = qp
- pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ shape = [] if padding == [] else list(np.shape(padding))
+ pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
out = Tensor(out_shape, out_dtype, "out")
out.quantization = qp.clone()
op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -449,9 +450,9 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs):
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
if type(axis) is list:
- indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis)
elif type(axis) is int:
- indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+ indices = create_const_tensor("indices", [], DataType.int32, axis)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index d091531d..6a0b58e3 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -303,55 +303,55 @@ def test_constraint_resize():
for resize_op in Op.op_set(Op.is_resize_op):
# IFM W and H == 1
op = testutil.create_op_with_quant_tensors(resize_op, [1, 1, 1, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
assert support.is_operator_supported(op)
# IFM == OFM
op = testutil.create_op_with_quant_tensors(resize_op, [1, 8, 8, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
assert support.is_operator_supported(op)
# IFM x2 == OFM ; align_corners = False
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
assert support.is_operator_supported(op)
# IFM x4 == OFM ; align_corners = False
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 16, 16, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16]))
assert support.is_operator_supported(op)
# IFM x8 == OFM ; align_corners = False
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 32, 32, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32]))
assert support.is_operator_supported(op)
# IFM -1 x2 == OFM -1 ; align_corners = True
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 7, 7, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7]))
op.attrs["align_corners"] = True
assert support.is_operator_supported(op)
# IFM -1 x4 == OFM -1 ; align_corners = True
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 13, 13, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13]))
op.attrs["align_corners"] = True
assert support.is_operator_supported(op)
# IFM -1 x8 == OFM -1 ; align_corners = True
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 25, 25, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25]))
op.attrs["align_corners"] = True
assert support.is_operator_supported(op)
# Invalid case - upscale size
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 17, 17, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17]))
assert not support.is_operator_supported(op)
# Invalid case - upscale size with align corners
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 15, 15, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15]))
op.attrs["align_corners"] = True
assert not support.is_operator_supported(op)
@@ -360,7 +360,7 @@ def test_constraint_resize_size():
for resize_op in Op.op_set(Op.is_resize_op):
# Invalid case - size != ofm size
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7]))
assert not support.is_operator_supported(op)
@@ -368,7 +368,7 @@ def test_constraint_resize_attrs():
for resize_op in Op.op_set(Op.is_resize_op):
# Invalid case - both align corners and half-pixel centers
op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8]))
op.attrs["align_corners"] = True
op.attrs["half_pixel_centers"] = True
assert not support.is_operator_supported(op)
@@ -395,7 +395,8 @@ def create_pad_op(
qp = testutil.default_quant_params()
in0 = Tensor(in_shape, in_dtype, "in")
in0.quantization = qp
- pad_tensor = create_const_tensor(name="pad", shape=list(np.shape(padding)), values=padding, dtype=pad_dtype)
+ shape = [] if padding == [] else list(np.shape(padding))
+ pad_tensor = create_const_tensor(name="pad", shape=shape, values=padding, dtype=pad_dtype)
out = Tensor(out_shape, out_dtype, "out")
out.quantization = qp.clone()
op = testutil.create_op(Op.Pad, [in0, pad_tensor], out)
@@ -587,9 +588,9 @@ def create_mean(input_shape, output_shape, axis, datatype, attrs):
ofm = Tensor(output_shape, datatype, "out")
ofm.quantization = testutil.default_quant_params()
if type(axis) is list:
- indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis, np.uint8)
+ indices = create_const_tensor("indices", [len(axis)], DataType.int32, axis)
elif type(axis) is int:
- indices = create_const_tensor("indices", [], DataType.int32, axis, np.uint8)
+ indices = create_const_tensor("indices", [], DataType.int32, axis)
op = testutil.create_op(Op.Mean, [ifm, indices], ofm, attrs)
return op
diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py
index acf35fe3..88fc8747 100644
--- a/ethosu/vela/test/testutil.py
+++ b/ethosu/vela/test/testutil.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2021 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -53,21 +53,13 @@ def create_elemwise_op(
ofm_quant=default_quant_params(),
):
# Creates elementwise operation with constant IFM/IFM2
- if datatype.size_in_bytes() == 1:
- np_type = np.uint8
- elif datatype.size_in_bytes() == 2:
- np_type = np.int16
- else:
- np_type = np.int32
op = Operation(op_type, name)
op.add_input_tensor(
- create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type, quantization=ifm_quant)
+ create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), quantization=ifm_quant)
)
if ifm2_shape is not None:
op.add_input_tensor(
- create_const_tensor(
- name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type, quantization=ifm2_quant
- )
+ create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), quantization=ifm2_quant)
)
ofm = Tensor(ofm_shape, datatype, name + "_ofm")
ofm.quantization = ofm_quant
@@ -89,25 +81,17 @@ def create_op_with_quant_tensors(
op.set_output_tensor(ofm)
# Optional weight tensor
if weights_shape is not None:
- if datatype.size_in_bytes() == 1:
- np_type = np.uint8
- elif datatype.size_in_bytes() == 2:
- np_type = np.int16
- else:
- np_type = np.int32
qp = default_quant_params()
if op.type is not Op.FullyConnected:
qp.zero_point = np.zeros(weights_shape)
- weights = create_const_tensor(
- "weights", weights_shape, datatype, np.zeros(weights_shape), np_type, quantization=qp
- )
+ weights = create_const_tensor("weights", weights_shape, datatype, np.zeros(weights_shape), quantization=qp)
op.add_input_tensor(weights)
# Optional bias tensor
if bias_shape is not None:
qp = default_quant_params()
if op.type is not Op.FullyConnected:
qp.zero_point = np.zeros(bias_shape)
- bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), np.int32, quantization=qp)
+ bias = create_const_tensor("bias", bias_shape, DataType.int32, np.zeros(bias_shape), quantization=qp)
op.add_input_tensor(bias)
if set_ifm_ofm_shapes:
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 242f0ea6..ff7b4863 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2020-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -343,17 +343,10 @@ def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
weight_quant.zero_point = 0
weight_quant.quant_dim = 0
ofm_dtype = ofm.dtype
- if ofm_dtype == DataType.uint8:
- weight_value_dtype = np.uint8
+ if ofm_dtype.type == BaseType.UnsignedInt:
weight_quant.quant_min = 0
weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
else:
- if ofm_dtype == DataType.int8:
- weight_value_dtype = np.int8
- else:
- assert ofm_dtype == DataType.int16
- weight_value_dtype = np.int16
-
weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
@@ -376,9 +369,8 @@ def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
create_const_tensor(
"weights",
weight_shape,
- ofm.dtype,
+ ofm_dtype,
np.array(weight_values).reshape(weight_shape),
- value_dtype=weight_value_dtype,
quantization=weight_quant,
),
1, # inputs tensor weight index
@@ -586,7 +578,6 @@ def convert_resizebilinear_to_depthwise_convolutions(op, half_pixel_centers=True
shape,
intermediate_tens.dtype,
np.array(kernel).reshape(shape),
- value_dtype=np.int8,
quantization=quant,
),
)
@@ -1227,9 +1218,7 @@ def convert_lrelu_to_mul_max(op, arch):
scalar, _ = scaling.elementwise_mul_scale(ifm.quantization.scale_f32, alpha, ofm.quantization.scale_f32)
else:
scalar = 1
- alpha_tens = create_const_tensor(
- op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], alpha_dtype.as_numpy_type(), quantization=quantization
- )
+ alpha_tens = create_const_tensor(op.name + "_alpha_scalar", [1], alpha_dtype, [scalar], quantization=quantization)
mul_alpha.add_input_tensor(alpha_tens)
fm_alpha = ofm.clone(op.name + "_alpha", set_unique=True)
mul_alpha.set_output_tensor(fm_alpha)
@@ -1256,9 +1245,7 @@ def convert_lrelu_to_mul_max(op, arch):
quantization.max = quantization.quant_max - quantization.quant_min
quantization.scale_f32 = np.float32(1)
quantization.zero_point = 0
- identity_tens = create_const_tensor(
- op.name + "_id_scalar", [], ifm.dtype, [1], np.uint8, quantization=quantization
- )
+ identity_tens = create_const_tensor(op.name + "_id_scalar", [], ifm.dtype, [1], quantization=quantization)
mul_identity.add_input_tensor(identity_tens)
# Make sure that fm_id is allocated to a different address than fm_alpha
fm_id = ofm.clone(op.name + "_id", set_unique=True)
@@ -1470,7 +1457,6 @@ def replace_pad_by_hw_pad(op: Operation, arch, nng):
shape,
op.ifm.dtype,
weights,
- np.uint8,
purpose=TensorPurpose.Weights,
quantization=quantization,
)
@@ -1526,7 +1512,7 @@ def convert_pad(op: Operation, arch, nng):
if top > 0:
shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ op.name + "_top", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
)
# If top/bottom or left/right are equal, the const tensors can be allocated to the same address
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
@@ -1538,7 +1524,6 @@ def convert_pad(op: Operation, arch, nng):
shape.as_list(),
ofm.dtype,
shape.elements() * [pad_value],
- np.uint8,
quantization=quant,
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
@@ -1548,14 +1533,14 @@ def convert_pad(op: Operation, arch, nng):
if left > 0:
shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
create_avg_pool_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
if right > 0:
shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
create_avg_pool_for_concat(
@@ -1715,7 +1700,6 @@ def convert_mean_to_depthwise_conv_or_avgpool(op, arch, nng):
weight_shape,
inp.dtype,
np.ones(weight_shape),
- value_dtype=np.uint8,
quantization=weight_quant,
),
1,
@@ -2008,8 +1992,7 @@ def tflite_optimise_graph(nng, arch):
ofm_clone = ofm.clone()
ofm_clone.values = ofm.values
ofm.values = None
- np_dtype = ofm.dtype.as_numpy_type()
- zero = create_const_tensor("zero", [1], ofm.dtype, [0], np_dtype, quantization=ofm.quantization)
+ zero = create_const_tensor("zero", [1], ofm.dtype, [0], quantization=ofm.quantization)
memcpy = create_add_nop(f"{ofm.name}_copy")
memcpy.add_input_tensor(ofm_clone)
memcpy.add_input_tensor(zero)
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 25d3dbc6..2a599aaa 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2021-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -164,7 +164,6 @@ def insert_add_copy_for_const(op, ifm_ofm_shape):
[1],
copy_tens.dtype,
[0],
- copy_tens.dtype.as_numpy_type(),
quantization=copy_tens.quantization,
)
copy_op = create_add_nop(name)
@@ -190,7 +189,6 @@ def insert_add_copy_op_after_tens(tens, ifm_ofm_shape):
[1],
copy_tens.dtype,
[0],
- copy_tens.dtype.as_numpy_type(),
quantization=copy_tens.quantization,
)
copy_op = create_add_nop(name)
@@ -267,9 +265,7 @@ def fix_sg_input_output_tosa(op, arch, nng):
def create_add_for_concat(concat_op, name, ifm, ifm_shape: Shape4D, write_offset: Shape4D):
"""Creates an add op for the given concat op/input feature map"""
ofm = concat_op.ofm
- ifm2 = create_const_tensor(
- name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
- )
+ ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
add_op = create_add_nop(name)
add_op.inputs = [ifm, ifm2]
@@ -306,9 +302,7 @@ def remove_splitsliceread(op, arch):
else:
name = op.name + "_add"
ofm = op.ofm
- ifm2 = create_const_tensor(
- name + "_zero_scalar", [1], ofm.dtype, [0], ofm.dtype.as_numpy_type(), quantization=ofm.quantization
- )
+ ifm2 = create_const_tensor(name + "_zero_scalar", [1], ofm.dtype, [0], quantization=ofm.quantization)
add_op = create_add_nop(name)
add_op.inputs = [op.ifm, ifm2]
add_op.outputs = [ofm]
@@ -476,14 +470,14 @@ def convert_pad_in_width(op):
if left > 0:
shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp0)
if right > 0:
shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
zero_tens = create_const_tensor(
- op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], quantization=quant
)
zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp0.with_width(ofm_shape.width - right))
@@ -816,9 +810,7 @@ def decomp_rewrite_pad(op, arch):
new_pad_tens = op.inputs[1].clone("_dim_{dim}")
name = op.inputs[1].name + f"_dim_{dim}"
- new_pad_tens = create_const_tensor(
- name, list(new_pad_input.shape), DataType.int32, new_pad_input, np.int32
- )
+ new_pad_tens = create_const_tensor(name, list(new_pad_input.shape), DataType.int32, new_pad_input)
pad_op.add_input_tensor(new_pad_tens)
new_ofm_shape = new_ifm_shape.copy()