aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/stats_writer.py
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-07-30 14:36:40 +0100
committertim.hall <tim.hall@arm.com>2020-08-04 11:20:09 +0000
commitb4249745057a0f71f2710013e7f263e46a496f53 (patch)
tree43d6aa3d1f7d4c2c85785746454ef639ad6a579b /ethosu/vela/stats_writer.py
parent775e396e2939b55f5f1ea5261260533cf168d12c (diff)
downloadethos-u-vela-b4249745057a0f71f2710013e7f263e46a496f53.tar.gz
vela: Protect against divide by zero
If the total cycle count is zero (for whatever reason), then a divide by zero can occur when calculating the midpoint_fps. This change protects against that by detecting when that is the case and instead setting the midpoint_fps to nan. Further calculations using that variable is safe and results in nan throughout. Change-Id: I2d29545d331a6eb5b27b6d9c931587c15f877e74 Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com>
Diffstat (limited to 'ethosu/vela/stats_writer.py')
-rw-r--r--ethosu/vela/stats_writer.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/ethosu/vela/stats_writer.py b/ethosu/vela/stats_writer.py
index c90d9876..af7b6997 100644
--- a/ethosu/vela/stats_writer.py
+++ b/ethosu/vela/stats_writer.py
@@ -85,7 +85,10 @@ def write_summary_metrics_csv(nng, summary_filename, arch):
)
midpoint_inference_time = nng.cycles[PassCycles.Total] / arch.npu_clock
- midpoint_fps = 1 / midpoint_inference_time
+ if midpoint_inference_time > 0:
+ midpoint_fps = 1 / midpoint_inference_time
+ else:
+ midpoint_fps = np.nan
n_passes = sum(len(sg.passes) for sg in nng.subgraphs)
n_cascaded_passes = sum(len(sg.cascaded_passes) for sg in nng.subgraphs)
@@ -231,7 +234,10 @@ def print_performance_metrics_for_strat(
orig_mem_areas_labels = [(v, v.display_name()) for v in MemArea.all()]
midpoint_inference_time = cycles[PassCycles.Total] / arch.npu_clock
- midpoint_fps = 1 / midpoint_inference_time
+ if midpoint_inference_time > 0:
+ midpoint_fps = 1 / midpoint_inference_time
+ else:
+ midpoint_fps = np.nan
mem_area_labels = [
(mem_area, label) for mem_area, label in orig_mem_areas_labels if np.sum(bandwidths[mem_area]) > 0