diff options
-rw-r--r-- | driver_library/src/ethosu.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp index 6b30827..2c498a8 100644 --- a/driver_library/src/ethosu.cpp +++ b/driver_library/src/ethosu.cpp @@ -57,12 +57,25 @@ size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape) { return size; } +size_t getTensorTypeSize(const enum tflite::TensorType type) { + switch (type) { + case tflite::TensorType::TensorType_UINT8: + case tflite::TensorType::TensorType_INT8: + return 1; + case tflite::TensorType::TensorType_INT16: + return 2; + default: + throw EthosU::Exception("Unsupported tensor type"); + } +} + vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap) { vector<size_t> dims; for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) { auto tensor = subgraph->tensors()->Get(*index); size_t size = getShapeSize(tensor->shape()); + size *= getTensorTypeSize(tensor->type()); if (size > 0) { dims.push_back(size); |