blob: 89d01222a8a2a9c95f3384be57f8fe8db007786c [file] [log] [blame]
# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Plot ADC updates as an SVG from a device logfile."""
import argparse
from datetime import datetime
from datetime import timedelta
import logging
from pathlib import Path
import sys
import pw_cli.log
from gonk_tools.gonk import DEFAULT_CSV_LOGFILE
from gonk_tools.adc import (
get_channel_names,
ADC_COUNT,
)
import matplotlib.pyplot as plt
import numpy as np
# Module-level logger. It is named after the enclosing package
# (``__package__``) rather than after this individual module.
_LOG = logging.getLogger(__package__)
def _parse_args():
    """Parse the command line for the plotting tool.

    Two options are recognised:

    * ``-i`` / ``--input-csv``: path of the CSV log to read; falls back to
      ``DEFAULT_CSV_LOGFILE`` when omitted.
    * ``-o`` / ``--output-svg``: path of the SVG to write; stays ``None``
      when omitted.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``, with
        ``input_csv`` and ``output_svg`` attributes.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    # Input: the CSV logfile to plot.
    input_help = f'Input CSV text file. Default: {DEFAULT_CSV_LOGFILE}'
    cli.add_argument(
        '-i',
        '--input-csv',
        default=Path(DEFAULT_CSV_LOGFILE),
        type=Path,
        help=input_help,
    )

    # Output: optional SVG destination.
    cli.add_argument(
        '-o',
        '--output-svg',
        default=None,
        type=Path,
        help='Output svg file.',
    )

    parsed = cli.parse_args()
    return parsed
def _plot_channels(axis, times, series, channel_names, ylabel, linewidth=0.7):
    """Draw one line per ADC channel on ``axis`` and label/decorate the axis.

    Args:
        axis: The matplotlib axes object to draw on.
        times: Array of x values (seconds) shared by every channel.
        series: One list of y values per ADC channel.
        channel_names: Legend label for each channel, indexed like ``series``.
        ylabel: Label for the y axis.
        linewidth: Width used for every plotted line.
    """
    axis.set_xlabel('Time (s)')
    axis.set_ylabel(ylabel)
    for i, values in enumerate(series):
        axis.plot(
            times,
            np.asarray(values),
            label=channel_names[i],
            linestyle='solid',
            linewidth=linewidth,
        )
    axis.legend()
    axis.grid(True)


def plot(
    input_csv: Path,
    output_svg: Path | None = None,
) -> int:
    """Plot ADC values from a CSV logfile.

    Each CSV row is expected to contain, in order: a host timestamp
    (``%Y%m%d %H:%M:%S.%f``), a delta in microseconds since the previous
    row, ``ADC_COUNT`` voltage values, ``ADC_COUNT`` current values,
    ``ADC_COUNT`` power values and one trailing header-pin-assert field.

    Rows that have the wrong number of fields, an unparsable host timestamp
    or unparsable numeric fields are skipped with a warning. The warning is
    suppressed for a row with an unparsable timestamp that contains a
    channel name, since that is taken to be the CSV header line.

    The first accepted row defines t = 0; the x position of every later row
    is the running sum of the ``delta microseconds`` fields.

    Three stacked plots are produced (vbus, Ishunt and Power), each with one
    line per ADC channel.

    Args:
        input_csv: Path of the CSV logfile to read.
        output_svg: Where to save the figure. When falsy (e.g. ``None``) the
            figure is shown interactively instead.

    Returns:
        Always 0.

    Raises:
        OSError: If ``input_csv`` cannot be opened.
    """
    # pylint: disable=too-many-locals
    pw_cli.log.install()

    _LOG.info('Input CSV: %s', input_csv)
    if output_svg:
        _LOG.info('Output svg file: %s', output_svg)

    interactive_plotting = not output_svg
    if interactive_plotting:
        _LOG.info('No outputs specified; plotting interactively')

    channel_names = get_channel_names()

    time_values: list[float] = []
    vbus_values: list[list[float]] = [[] for _ in range(ADC_COUNT)]
    vshunt_values: list[list[float]] = [[] for _ in range(ADC_COUNT)]
    power_values: list[list[float]] = [[] for _ in range(ADC_COUNT)]

    # CSV Fields in order:
    csv_field_count = 1  # For host timestamp
    csv_field_count += 1  # For delta microseconds
    csv_field_count += ADC_COUNT * 3  # Voltage, current, and power for each ADC
    csv_field_count += 1  # For 'Header pin assert' string at the end

    # Column offsets of each group of per-ADC values.
    voltage_start = 2
    current_start = voltage_start + ADC_COUNT
    power_start = current_start + ADC_COUNT
    gpio_index = power_start + ADC_COUNT

    # Host timestamp of the first accepted row; None until one is seen.
    start_time: datetime | None = None
    # Time elapsed since the first accepted row. timedelta arithmetic is
    # exact, so no floating point error accumulates across rows.
    elapsed = timedelta(0)

    with input_csv.open() as f:
        for line in f:
            # Drop the line terminator so log messages stay on one line.
            line = line.rstrip('\r\n')
            parts = [field.strip() for field in line.split(',')]

            if len(parts) != csv_field_count:
                _LOG.warning(
                    'Skipping line due to unexpected number of '
                    'CSV fields: %i',
                    len(parts),
                )
                _LOG.warning('Fields: %s', parts)
                _LOG.warning('Line: "%s"', line)
                continue  # Skip this line

            # Extract the host timestamp
            try:
                dt = datetime.strptime(parts[0], '%Y%m%d %H:%M:%S.%f')
            except ValueError as err:
                # Output a warning if this isn't the csv header line.
                if not any(name in line for name in channel_names):
                    _LOG.warning(
                        'Skipping line due to error parsing host time: %s',
                        err,
                    )
                    _LOG.warning('  Line: "%s"', line)
                continue  # Skip this line

            # Parse every numeric field before touching any accumulated
            # state, so that a malformed row is skipped cleanly.
            try:
                delta_micros = int(parts[1])
                voltages = [
                    float(v) for v in parts[voltage_start:current_start]
                ]
                currents = [
                    float(v) for v in parts[current_start:power_start]
                ]
                powers = [float(v) for v in parts[power_start:gpio_index]]
            except ValueError as err:
                _LOG.warning(
                    'Skipping line due to error parsing values: %s', err
                )
                _LOG.warning('  Line: "%s"', line)
                continue  # Skip this line

            # TODO(tonymd): Plot GPIO events somehow.
            _gpio_assert = parts[gpio_index]

            if start_time is None:
                # The first accepted row is t = 0; its delta is ignored.
                start_time = dt
            else:
                elapsed += timedelta(microseconds=delta_micros)

            # Save the timestamp
            time_values.append(elapsed.total_seconds())

            for i, voltage in enumerate(voltages):
                vbus_values[i].append(voltage)
                vshunt_values[i].append(currents[i])
                power_values[i].append(powers[i])

    # Plot vbus, shunt current and power values, one subplot each.
    _fig, (ax1, ax2, ax3) = plt.subplots(  # type: ignore
        3, 1, layout='constrained', figsize=[11.67, 8.27]
    )

    times = np.asarray(time_values)
    _plot_channels(ax1, times, vbus_values, channel_names, 'vbus')
    _plot_channels(ax2, times, vshunt_values, channel_names, 'Ishunt')
    _plot_channels(ax3, times, power_values, channel_names, 'Power')

    if output_svg:
        plt.savefig(output_svg)
        _LOG.info('Output svg saved: %s', output_svg.resolve())

    if interactive_plotting:
        plt.show()

    return 0
def main() -> None:
    """Parse the command line, run ``plot`` and exit with its return code."""
    parsed = _parse_args()
    # Every parsed option is forwarded to plot() as a keyword argument.
    exit_code = plot(**vars(parsed))
    sys.exit(exit_code)
# Only run the tool when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()