blob: af6798a35bc3f93c502b52aa83568a9e80052462 [file] [edit]
#!/usr/bin/env python3
# Copyright (c) 2025 Basalte bv
#
# SPDX-License-Identifier: Apache-2.0
"""
A script to plot the data generated by size_report as a sunburst chart.
When you build the ram_report or rom_report targets you end up
with a json file in the build directory that can be used as input
for this script.
Example:
./scripts/footprint/plot.py build/ram.json
Requires plotly to be installed, for example with pip:
pip install plotly
"""
import argparse
import json
try:
import plotly.graph_objects as go
except ImportError:
print("Missing dependency: You need to install plotly (see scripts/requirements-extra.txt).")
raise
def parse_args():
    """Parse the command line of this script.

    Returns:
        The ``argparse.Namespace`` holding ``input`` (path of the json
        file), ``html`` (output html path, ``None`` when not given) and
        ``depth`` (maximum render depth as an int, 4 unless overridden).
    """
    cli = argparse.ArgumentParser(
        allow_abbrev=False,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    depth_help = 'Maximum render depth, pass -1 to render all levels. Defaults to 4'

    cli.add_argument('input', help='Input json file')
    cli.add_argument('--html', help='Output html file')
    cli.add_argument('--depth', type=int, default=4, help=depth_help)

    parsed = cli.parse_args()
    return parsed
def generate_figure(data, depth=4):
    """Build a plotly sunburst figure from parsed size_report json data.

    The tree rooted at ``data['symbols']`` is flattened, depth-first in
    pre-order, into the parallel ``ids``/``labels``/``parents``/``values``
    arrays that ``go.Sunburst`` expects.

    Args:
        data: The decoded json document. ``symbols`` is the root node; each
            node is a dict that may hold ``identifier``, ``name``, ``size``,
            ``address`` (formatted with ``:08x``, so it must be an int),
            ``section`` and ``children``. ``total_size``, when greater than
            zero, is used to show each node's share on hover.
        depth: Passed to plotly as ``maxdepth``; -1 renders all levels.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single sunburst trace.
    """
    totalsize = data.get('total_size', 0)

    ids = []
    labels = []
    parents = []
    values = []
    hovertext = []
    # Mirror of ``ids`` for O(1) membership tests; looking identifiers up in
    # the list itself made the de-duplication quadratic in the symbol count.
    seen = set()

    def unique_id(identifier):
        """Return ``identifier``, suffixed with ``_<n>`` if already in use."""
        if identifier not in seen:
            return identifier
        idx = 0
        while f'{identifier}_{idx}' in seen:
            idx += 1
        return f'{identifier}_{idx}'

    def iter_node(node: dict, parent=''):
        identifier = node.get('identifier')
        if identifier is None:
            # A node without an identifier is skipped together with its subtree.
            return
        # Identifiers aren't unique, add a suffix to make them unique
        identifier = unique_id(identifier)
        seen.add(identifier)

        # Read the size once with a default, so a node lacking 'size' no
        # longer crashes the percentage computation below.
        size = node.get('size', 0)

        ids.append(identifier)
        labels.append(node.get('name', ''))
        parents.append(parent)
        values.append(size)

        details = []
        if totalsize > 0:
            details.append(f'percentage: {size / totalsize:.2%}')
        if 'address' in node:
            details.append(f'address: 0x{node.get("address"):08x}')
        if 'section' in node:
            details.append(f'section: {node.get("section")}')
        hovertext.append("<br>".join(details))

        for child in node.get('children', ()):
            iter_node(child, identifier)

    iter_node(data.get('symbols', {}))

    fig = go.Figure(
        go.Sunburst(
            ids=ids,
            labels=labels,
            parents=parents,
            values=values,
            hovertext=hovertext,
            # Each parent's value already includes its children's sizes.
            branchvalues='total',
            maxdepth=depth,
        ),
        skip_invalid=True,
    )
    fig.update_layout(margin={'t': 0, 'l': 0, 'r': 0, 'b': 0})
    fig.update_traces(textfont=dict(size=24))
    return fig
def main():
    """Entry point: parse arguments, load the json report and plot it.

    With ``--html`` the figure is written to that file and the script
    returns; otherwise it is rendered in the default web browser.
    """
    args = parse_args()

    # JSON text is UTF-8 (RFC 8259). Do not depend on the locale's default
    # encoding, which is not UTF-8 on every platform (e.g. many Windows setups).
    with open(args.input, encoding='utf-8') as f:
        data = json.load(f)

    fig = generate_figure(data, args.depth)

    if args.html:
        fig.write_html(args.html, auto_open=False)
        return

    print("Opening the default browser to render the generated plot.")
    fig.show(renderer="browser")


if __name__ == "__main__":
    main()