aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2020-11-05 15:56:08 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-11-16 09:10:54 +0000
commite05de456770926fdc057478739d1b96b7f651756 (patch)
tree784fb9792a959a97c897ece67b23787abbcfcae1
parente8a5a78dd16ec979c7a7bb1f5bd87e9b2909c32d (diff)
downloadethos-u-vela-e05de456770926fdc057478739d1b96b7f651756.tar.gz
MLBEDSW-3301: Vela fails ungracefully when reading string buffers
When encountering a sparse string buffer, Vela fails both due to missing a mapping for a Numpy string type and also for not being able to read sparse buffers. The failing line is attempting to reshape a [100] buffer into a [3, 5] tensor which does not work due to Vela treating the buffer as non-sparse. The solution here is to simply not do the reshape for string buffers (which all appear to be sparse) since it is not something that will be supported in the future anyway. The related operator can then be pushed to the CPU as expected. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: Iea0af6cd60a691f975209014b6aa098dde8d6a4b
-rw-r--r--ethosu/vela/test/test_tflite_reader.py33
-rw-r--r--ethosu/vela/tflite_mapping.py1
-rw-r--r--ethosu/vela/tflite_reader.py7
3 files changed, 38 insertions, 3 deletions
diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py
index 23abb4a..14c9b20 100644
--- a/ethosu/vela/test/test_tflite_reader.py
+++ b/ethosu/vela/test/test_tflite_reader.py
@@ -18,9 +18,11 @@
from unittest.mock import MagicMock
from unittest.mock import patch
+import numpy as np
import pytest
from ethosu.vela.operation import Op
+from ethosu.vela.tflite.TensorType import TensorType
from ethosu.vela.tflite_reader import TFLiteSubgraph
@@ -79,3 +81,34 @@ class TestTFLiteSubgraph:
assert len(created_op.inputs) == expected
assert created_op.outputs[0].name == "tensor_{}".format(output)
assert inputs[-1] != -1 or not created_op.inputs[-1]
+
+ string_buffer_testdata = [
+ (np.array([np.random.randint(256) for _ in range(100)], dtype=np.uint8), [3, 5]),
+ (np.array([np.random.randint(256) for _ in range(100)], dtype=np.int16), [10, 10]),
+ (np.array([np.random.randint(256) for _ in range(100)], dtype=np.float32), [100]),
+ (np.array([], dtype=np.int8), [30]),
+ ]
+
+ @pytest.mark.parametrize("buffer, tens_shape", string_buffer_testdata)
+ def test_parse_tensor_with_string_buffer(self, buffer, tens_shape):
+ tens_data = MagicMock()
+ tens_data.ShapeAsNumpy = MagicMock(return_value=np.array(tens_shape), dtype=np.int32)
+ tens_data.Name = MagicMock(return_value=b"test_data")
+ tens_data.Type = MagicMock(return_value=TensorType.STRING)
+ tens_data.Quantization = MagicMock(return_value=None)
+ tens_data.Buffer = MagicMock(return_value=0)
+
+ tfl_sg = MagicMock()
+ tfl_sg.Name = MagicMock(return_value=b"test_sg")
+ tfl_sg.TensorsLength = MagicMock(return_value=0)
+ tfl_sg.OperatorsLength = MagicMock(return_value=0)
+ tfl_sg.OutputsAsNumpy = MagicMock(return_value=[])
+ tfl_sg.InputsAsNumpy = MagicMock(return_value=[])
+
+ graph = MagicMock()
+ graph.buffers = [buffer]
+
+ subgraph = TFLiteSubgraph(graph, tfl_sg)
+
+ tens = subgraph.parse_tensor(tens_data)
+ assert tens.values is None
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index bbd44b0..44ecedc 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -171,6 +171,7 @@ datatype_map_numpy = {
TensorType.BOOL: np.bool,
TensorType.COMPLEX64: np.complex64,
TensorType.COMPLEX128: np.complex128,
+ TensorType.STRING: np.dtype("S1"),
}
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index b3b0720..c190f7e 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -106,7 +106,8 @@ class TFLiteSubgraph:
np_shape = tens_data.ShapeAsNumpy()
shape = list(np_shape) if type(np_shape) is np.ndarray else []
name = decode_str(tens_data.Name())
- dtype = datatype_map[tens_data.Type()]
+ tens_dtype = tens_data.Type()
+ dtype = datatype_map[tens_dtype]
tens = Tensor(shape, dtype, name)
quant = tens_data.Quantization()
@@ -129,8 +130,8 @@ class TFLiteSubgraph:
tens.values = None
buf = self.graph.buffers[tens_data.Buffer()]
- if buf is not None:
- tens.values = np.array(buf.view(datatype_map_numpy[tens_data.Type()]).reshape(shape))
+ if buf is not None and dtype != DataType.string:
+ tens.values = np.array(buf.view(datatype_map_numpy[tens_dtype]).reshape(shape))
if tens.quantization is not None:
tens.quant_values = tens.values
tens.values = tens.quantization.dequantize(tens.quant_values)