From 5e632c242e4e6445c067f74ccc3bf8f14da7f60b Mon Sep 17 00:00:00 2001 From: Kristofer Jonsson Date: Wed, 25 Nov 2020 10:58:38 +0100 Subject: 16 bit network support Change-Id: I65cd027dea115d2f50f302ebe997d6d2525e0d7e --- driver_library/src/ethosu.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index 6b30827..2c498a8 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -57,12 +57,25 @@ size_t getShapeSize(const flatbuffers::Vector *shape) { return size; } +size_t getTensorTypeSize(const enum tflite::TensorType type) { + switch (type) { + case tflite::TensorType::TensorType_UINT8: + case tflite::TensorType::TensorType_INT8: + return 1; + case tflite::TensorType::TensorType_INT16: + return 2; + default: + throw EthosU::Exception("Unsupported tensor type"); + } +} + vector getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector *tensorMap) { vector dims; for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { auto tensor = subgraph->tensors()->Get(*index); size_t size = getShapeSize(tensor->shape()); + size *= getTensorTypeSize(tensor->type()); if (size > 0) { dims.push_back(size); -- cgit v1.2.1