From e05de456770926fdc057478739d1b96b7f651756 Mon Sep 17 00:00:00 2001 From: Dwight Lidman Date: Thu, 5 Nov 2020 15:56:08 +0100 Subject: MLBEDSW-3301: Vela fails ungracefully when reading string buffers When encountering a sparse string buffer, Vela fails both due to missing a mapping for a Numpy string type and also for not being able to read sparse buffers. The failing line is attempting to reshape a [100] buffer into a [3, 5] tensor which does not work due to Vela treating the buffer as non-sparse. The solution here is to simply not do the reshape for string buffers (which all appear to be sparse) since it is not something that will be supported in the future anyway. The related operator can then be pushed to the CPU as expected. Signed-off-by: Dwight Lidman Change-Id: Iea0af6cd60a691f975209014b6aa098dde8d6a4b --- ethosu/vela/test/test_tflite_reader.py | 33 +++++++++++++++++++++++++++++++++ ethosu/vela/tflite_mapping.py | 1 + ethosu/vela/tflite_reader.py | 7 ++++--- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py index 23abb4a0..14c9b204 100644 --- a/ethosu/vela/test/test_tflite_reader.py +++ b/ethosu/vela/test/test_tflite_reader.py @@ -18,9 +18,11 @@ from unittest.mock import MagicMock from unittest.mock import patch +import numpy as np import pytest from ethosu.vela.operation import Op +from ethosu.vela.tflite.TensorType import TensorType from ethosu.vela.tflite_reader import TFLiteSubgraph @@ -79,3 +81,34 @@ class TestTFLiteSubgraph: assert len(created_op.inputs) == expected assert created_op.outputs[0].name == "tensor_{}".format(output) assert inputs[-1] != -1 or not created_op.inputs[-1] + + string_buffer_testdata = [ + (np.array([np.random.randint(256) for _ in range(100)], dtype=np.uint8), [3, 5]), + (np.array([np.random.randint(256) for _ in range(100)], dtype=np.int16), [10, 10]), + (np.array([np.random.randint(256) for _ in range(100)], dtype=np.float32), [100]), + (np.array([], dtype=np.int8), [30]), + ] + + @pytest.mark.parametrize("buffer, tens_shape", string_buffer_testdata) + def test_parse_tensor_with_string_buffer(self, buffer, tens_shape): + tens_data = MagicMock() + tens_data.ShapeAsNumpy = MagicMock(return_value=np.array(tens_shape), dtype=np.int32) + tens_data.Name = MagicMock(return_value=b"test_data") + tens_data.Type = MagicMock(return_value=TensorType.STRING) + tens_data.Quantization = MagicMock(return_value=None) + tens_data.Buffer = MagicMock(return_value=0) + + tfl_sg = MagicMock() + tfl_sg.Name = MagicMock(return_value=b"test_sg") + tfl_sg.TensorsLength = MagicMock(return_value=0) + tfl_sg.OperatorsLength = MagicMock(return_value=0) + tfl_sg.OutputsAsNumpy = MagicMock(return_value=[]) + tfl_sg.InputsAsNumpy = MagicMock(return_value=[]) + + graph = MagicMock() + graph.buffers = [buffer] + + subgraph = TFLiteSubgraph(graph, tfl_sg) + + tens = subgraph.parse_tensor(tens_data) + assert tens.values is None diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index bbd44b0b..44ecedcc 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -171,6 +171,7 @@ datatype_map_numpy = { TensorType.BOOL: np.bool, TensorType.COMPLEX64: np.complex64, TensorType.COMPLEX128: np.complex128, + TensorType.STRING: np.dtype("S1"), } diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index b3b0720a..c190f7e2 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -106,7 +106,8 @@ class TFLiteSubgraph: np_shape = tens_data.ShapeAsNumpy() shape = list(np_shape) if type(np_shape) is np.ndarray else [] name = decode_str(tens_data.Name()) - dtype = datatype_map[tens_data.Type()] + tens_dtype = tens_data.Type() + dtype = datatype_map[tens_dtype] tens = Tensor(shape, dtype, name) quant = tens_data.Quantization() @@ -129,8 +130,8 @@ class TFLiteSubgraph: tens.values = None buf = self.graph.buffers[tens_data.Buffer()] - if buf is not None: - tens.values = np.array(buf.view(datatype_map_numpy[tens_data.Type()]).reshape(shape)) + if buf is not None and dtype != DataType.string: + tens.values = np.array(buf.view(datatype_map_numpy[tens_dtype]).reshape(shape)) if tens.quantization is not None: tens.quant_values = tens.values tens.values = tens.quantization.dequantize(tens.quant_values) -- cgit v1.2.1