aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorDwight Lidman <dwight.lidman@arm.com>2020-11-05 15:56:08 +0100
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-11-16 09:10:54 +0000
commite05de456770926fdc057478739d1b96b7f651756 (patch)
tree784fb9792a959a97c897ece67b23787abbcfcae1 /ethosu/vela/tflite_reader.py
parente8a5a78dd16ec979c7a7bb1f5bd87e9b2909c32d (diff)
downloadethos-u-vela-e05de456770926fdc057478739d1b96b7f651756.tar.gz
MLBEDSW-3301: Vela fails ungracefully when reading string buffers
When encountering a sparse string buffer, Vela fails both due to missing a mapping for a Numpy string type and also for not being able to read sparse buffers. The failing line is attempting to reshape a [100] buffer into a [3, 5] tensor which does not work due to Vela treating the buffer as non-sparse. The solution here is to simply not do the reshape for string buffers (which all appear to be sparse) since it is not something that will be supported in the future anyway. The related operator can then be pushed to the CPU as expected. Signed-off-by: Dwight Lidman <dwight.lidman@arm.com> Change-Id: Iea0af6cd60a691f975209014b6aa098dde8d6a4b
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index b3b0720a..c190f7e2 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -106,7 +106,8 @@ class TFLiteSubgraph:
np_shape = tens_data.ShapeAsNumpy()
shape = list(np_shape) if type(np_shape) is np.ndarray else []
name = decode_str(tens_data.Name())
- dtype = datatype_map[tens_data.Type()]
+ tens_dtype = tens_data.Type()
+ dtype = datatype_map[tens_dtype]
tens = Tensor(shape, dtype, name)
quant = tens_data.Quantization()
@@ -129,8 +130,8 @@ class TFLiteSubgraph:
tens.values = None
buf = self.graph.buffers[tens_data.Buffer()]
- if buf is not None:
- tens.values = np.array(buf.view(datatype_map_numpy[tens_data.Type()]).reshape(shape))
+ if buf is not None and dtype != DataType.string:
+ tens.values = np.array(buf.view(datatype_map_numpy[tens_dtype]).reshape(shape))
if tens.quantization is not None:
tens.quant_values = tens.values
tens.values = tens.quantization.dequantize(tens.quant_values)