import argparse
import asyncio
import contextlib
import datetime
import ipaddress
import json
import logging
import os.path
import pathlib
import random
import re
import signal
import socket
import ssl
import string
import subprocess
import sys
import tempfile
import xml.etree.ElementTree
from enum import Enum
from pathlib import Path
from typing import Awaitable, Callable, Literal, Optional, Tuple
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes, CertificatePublicKeyTypes
from cryptography.x509.oid import ExtendedKeyUsageOID, NameOID
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request, Response
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from zeroconf import ServiceInfo, Zeroconf
log = logging.getLogger(__name__)
module_dir_path = os.path.dirname(os.path.realpath(__file__))
templates_path = os.path.join(module_dir_path, "templates")
static_path = os.path.join(module_dir_path, "static")
'''
The initialisation segments must have the .init extension, as per CMAF-Ingest requirements.
https://dashif.org/Ingest/#interface-2-naming
'''
VALID_EXTENSIONS = ["mpd", "m3u8", "m4s", "init"]
class WorkingDirectory:
"""
Collection of utilities that add conventions around the files used by this program.
"""
tmp = None
def __init__(self, directory: Optional[str] = None) -> None:
if directory is None:
self.tmp = tempfile.TemporaryDirectory(prefix="TC_PAVS_1_0")
else:
d = pathlib.Path(directory)
d.mkdir(parents=True, exist_ok=True)
self.directory = d
def __enter__(self):
return self
def __exit__(self, exc, value, tb):
self.cleanup()
def cleanup(self):
if self.tmp:
self.tmp.cleanup()
def root_dir(self) -> Path:
return Path(self.tmp.name) if self.tmp else self.directory
def path(self, *paths: str) -> Path:
return Path(os.path.join(self.root_dir(), *paths))
def mkdir(self, *paths: str, is_file=False) -> Path:
"""
Create a directory using the given path rooted in the working directory.
If a file is provided, the directory up to that file will be created instead.
Returns the full path.
"""
p = self.path(*paths)
# Create the parent directories if they don't already exist
p2 = pathlib.Path(p)
if is_file:
p2 = p2.parent
p2.mkdir(parents=True, exist_ok=True)
return p
def print_tree(self):
# TODO Convert this helper to build a HTML representation for use in the UI
def tree(dir_path: pathlib.Path, prefix: str = ""):
"""A recursive generator, given a directory Path object
will yield a visual tree structure line by line
with each line prefixed by the same characters
"""
# prefix components:
space = " "
branch = "│ "
# pointers:
tee = "├── "
last = "└── "
contents = list(dir_path.iterdir())
# contents each get pointers that are ├── with a final └── :
pointers = [tee] * (len(contents) - 1) + [last]
for pointer, path in zip(pointers, contents):
is_dir = path.is_dir()
yield prefix + pointer + path.name + ("/" if is_dir else "")
if path.is_dir(): # extend the prefix and recurse:
extension = branch if pointer == tee else space
# i.e. use a space when the entry above was the last one (└── ), so no more │ is drawn
yield from tree(path, prefix=prefix + extension)
root = self.root_dir()
print(root)
for line in tree(pathlib.Path(root)):
print(line)
class CAHierarchy:
"""
Utilities to manage a CA hierarchy on disk.
"""
default_ca_duration = datetime.timedelta(days=365.25*20)
client_key_usage_cert = x509.KeyUsage(
digital_signature=True,
content_commitment=False,
key_encipherment=True,
data_encipherment=False,
key_agreement=False,
key_cert_sign=False,
crl_sign=False,
encipher_only=False,
decipher_only=False,
)
server_key_usage_cert = x509.KeyUsage(
digital_signature=True,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=False,
crl_sign=False,
encipher_only=False,
decipher_only=False,
)
def __init__(self, base: Path, name: str, kind: Literal['server', 'client']) -> None:
self.name = name
self.kind = kind
self.directory = base
self.root_cert_path = self.directory / "root.pem"
self.root_key_path = self.directory / "root.key"
if self.root_key_path.exists() and self.root_cert_path.exists():
# Root certificate and key already exist on disk, reuse them
self.root_cert = x509.load_pem_x509_certificate(
self.root_cert_path.read_bytes()
)
self.root_key = serialization.load_pem_private_key(
self.root_key_path.read_bytes(), None
)
log.info(f"CA Hierarchy loaded from disk: {self.name}")
elif self.root_key_path.exists() or self.root_cert_path.exists():
# Only one of the two files exists, bailing out
log.error("Root certificate material only partially exists on disk, stopping early")
sys.exit(1)
else:
# Start generating the root certificate
self.root_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
rand_suffix = "".join(
random.choices(string.ascii_letters + string.digits, k=16)
)
root_cert_subject = x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
x509.NameAttribute(
NameOID.COMMON_NAME, "TC_PAVS root " + rand_suffix
),
]
)
self.root_cert = (
x509.CertificateBuilder()
.subject_name(root_cert_subject)
.issuer_name(root_cert_subject)
.public_key(self.root_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(
datetime.datetime.now(datetime.timezone.utc) + self.default_ca_duration
)
.add_extension(
# We make it so that our root can only issue leaf certificates, no intermediate here.
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
.add_extension(
x509.KeyUsage(
digital_signature=True,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(
self.root_key.public_key()
),
critical=False,
)
.sign(self.root_key, hashes.SHA256())
)
self._save_cert("root", self.root_cert, self.root_key, False)
log.info(f"CA Hierarchy generated: {self.name}")
def _save_cert(
self,
name: str,
cert: x509.Certificate,
key: Optional[CertificateIssuerPrivateKeyTypes],
bundle_root: bool,
) -> tuple[Optional[Path], Path]:
"""
Private helper that saves a certificate and key to the hierarchy folder.
This tool isn't meant to be used in production; it exists to help with development
and therefore aims to make the CA hierarchy as accessible as possible, which in
turn makes it very insecure.
"""
cert_path = self.directory / f"{name}.pem"
key_path = self.directory / f"{name}.key" if key else None
if key and key_path:
with open(key_path, "wb") as f:
f.write(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
)
with open(cert_path, "wb") as f:
f.write(cert.public_bytes(serialization.Encoding.PEM))
if bundle_root:
f.write(b"\n")
f.write(self.root_cert.public_bytes(serialization.Encoding.PEM))
return (key_path, cert_path)
def _sign_cert(
self,
dns: str,
public_key: CertificatePublicKeyTypes,
duration: datetime.timedelta,
ip_address: Optional[str] = None
) -> x509.Certificate:
"""
Generate and sign a certificate.
"""
# Use ip_address for Common Name if provided, otherwise use dns
common_name = ip_address if ip_address else dns
# Sign certificate
subject = x509.Name(
[
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "TC_PAVS"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]
)
extended_key_usage = [ExtendedKeyUsageOID.CLIENT_AUTH] if self.kind == "client" else [
ExtendedKeyUsageOID.SERVER_AUTH]
builder = (x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(self.root_cert.subject)
.public_key(public_key)
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(
datetime.datetime.now(datetime.timezone.utc) + duration
)
.add_extension(
x509.BasicConstraints(ca=False, path_length=None),
critical=False,
)
.add_extension(
self.client_key_usage_cert if self.kind == "client" else self.server_key_usage_cert,
critical=True,
)
.add_extension(
x509.ExtendedKeyUsage(extended_key_usage),
critical=False,
)
.add_extension(
x509.SubjectKeyIdentifier.from_public_key(public_key),
critical=False,
)
.add_extension(
x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
self.root_cert.extensions.get_extension_for_class(
x509.SubjectKeyIdentifier
).value
),
critical=False,
)
.add_extension(x509.CRLDistributionPoints([x509.DistributionPoint(
full_name=[x509.UniformResourceIdentifier("http://not.a.valid.website.com/some/path/to/a.crl")],
relative_name=None,
reasons=None,
crl_issuer=None
)]), critical=False)
)
if self.kind == 'server':
san_names: list[x509.DNSName | x509.IPAddress] = [x509.DNSName(dns)]
if ip_address:
san_names.append(x509.IPAddress(ipaddress.ip_address(ip_address)))
builder.add_extension(
x509.SubjectAlternativeName(san_names),
critical=False,
)
return builder.sign(self.root_key, hashes.SHA256())
def gen_cert(self, dns: str, csr: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1)) -> tuple[Optional[Path], Path, bool]:
"""
Generate a certificate signed by this CA hierarchy using the provided CSR.
Returns the path to the key (None when signing an external CSR), the cert, and whether it was reused or not.
"""
signing_request = x509.load_pem_x509_csr(csr.encode('utf-8'))
signing_request.public_key()
# Unless override is requested, first check whether an existing keypair can be reused
if not override:
cert_path = self.directory / f"{dns}.pem"
key_path = self.directory / f"{dns}.key"
if cert_path.exists() and key_path.exists():
return (key_path, cert_path, True)
# Sign certificate
cert = self._sign_cert(dns, signing_request.public_key(), duration)
# Save that information to disk
(key_path, cert_bundle_path) = self._save_cert(
dns, cert, None, bundle_root=True
)
log.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
return (key_path, cert_bundle_path, False)
def gen_keypair(self, dns: str,
override=False,
duration: datetime.timedelta = datetime.timedelta(hours=1),
ip_address: Optional[str] = None) -> tuple[Path, Path, bool]:
"""
Generate a private key as well as the associated certificate signed by this CA
hierarchy. Returns the path to the key, cert, and whether it was reused or not.
"""
# Unless override is requested, first check whether an existing keypair can be reused
if not override:
cert_path = self.directory / f"{dns}.pem"
key_path = self.directory / f"{dns}.key"
if cert_path.exists() and key_path.exists():
cert = x509.load_pem_x509_certificate(cert_path.read_bytes())
if datetime.datetime.now(datetime.timezone.utc) < cert.not_valid_after_utc:
# We only reuse the certificate/key if the cert is still valid
return (key_path, cert_path, True)
# Generate private key
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
# Sign certificate
cert = self._sign_cert(dns, key.public_key(), duration, ip_address=ip_address)
# Save that information to disk
(key_path, cert_bundle_path) = self._save_cert(dns, cert, key, bundle_root=True)
if key_path is None:
raise ValueError("Key path should always be set")
log.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
return (key_path, cert_bundle_path, False)
class SignClientCertificate(BaseModel):
"""Request model to sign a client certificate"""
csr: str
class TrackNameRequest(BaseModel):
"""Request model to update track name for a stream"""
track_name: str
class SupportedIngestInterface(str, Enum):
cmaf = "cmaf-ingest" # Interface 1
dash = "dash" # Interface 2, DASH version
hls = "hls" # Interface 2, HLS version
class UploadError(BaseModel):
session_id: Optional[int]
file_path: str
reasons: list[str]
class ValidUpload(BaseModel):
session_id: Optional[int]
file_path: str
class Session(BaseModel):
# The id is the 1-based position of this session in the stream's sessions list.
# Keeping a duplicated value here to have it included in API responses.
id: int
uploaded_segments: list[Tuple[str, str]] = []
uploaded_manifests: list[Tuple[str, str]] = []
complete: bool = False
class Stream(BaseModel):
# Configuration of the PushAv stream
id: int
strict_mode: bool = True
interface: SupportedIngestInterface
track_name: Optional[str] = None
# Keep track of the various sessions encountered
sessions: list[Session] = []
# tracking uploads with unique file paths
error_uploads: list[UploadError] = []
valid_uploads: list[ValidUpload] = []
# Utilities
def save_to_disk(self, wd: WorkingDirectory):
p = wd.path("streams", str(self.id), "stream.json")
with open(p, 'w', encoding='utf-8') as f:
json.dump(self.model_dump(), f, ensure_ascii=False, indent=4)
def new_session(self) -> Session:
session_id = len(self.sessions) + 1
session = Session(id=session_id)
self.sessions.append(session)
return session
def last_in_progress_session(self) -> Optional[Session]:
if len(self.sessions) == 0:
return None
last_session = self.sessions[-1]
if not last_session.complete:
return last_session
return None
def _is_file_in_error_uploads(self, file_path: str) -> bool:
"""Check if a file path already exists in error_uploads"""
return any(error.file_path == file_path for error in self.error_uploads)
def _is_file_in_valid_uploads(self, file_path: str) -> bool:
"""Check if a file path already exists in valid_uploads"""
return any(valid.file_path == file_path for valid in self.valid_uploads)
def add_error_upload(self, session_id: Optional[int], file_path: str, reasons: list[str]):
"""Add a file to error_uploads if it doesn't already exist"""
if not self._is_file_in_error_uploads(file_path):
# Check if file exists in valid_uploads and remove it
if self._is_file_in_valid_uploads(file_path):
self.valid_uploads = [valid for valid in self.valid_uploads if valid.file_path != file_path]
self.error_uploads.append(UploadError(session_id=session_id, file_path=file_path, reasons=reasons))
def add_valid_upload(self, session_id: Optional[int], file_path: str):
"""Add a file to valid_uploads if it doesn't already exist and isn't in error_uploads"""
if not self._is_file_in_valid_uploads(file_path) and not self._is_file_in_error_uploads(file_path):
self.valid_uploads.append(ValidUpload(session_id=session_id, file_path=file_path))
class PushAvServer:
templates = Jinja2Templates(directory=templates_path)
def __init__(self, wd: WorkingDirectory, device_hierarchy: CAHierarchy, strict_mode: bool):
self.wd = wd
self.device_hierarchy = device_hierarchy
self.strict_mode = strict_mode
self.router = APIRouter()
# In-memory map to track camera streams
self.streams = self._list_streams()
# UI
self.router.add_api_route("/", self.index, methods=["GET"], response_class=RedirectResponse)
self.router.add_api_route("/ui/streams", self.ui_streams_list, methods=["GET"], response_class=HTMLResponse)
self.router.add_api_route("/ui/streams/{stream_id}", self.ui_streams_details, methods=["GET"])
self.router.add_api_route("/ui/streams/{stream_id}/{file_path:path}", self.ui_streams_file_details, methods=["GET"])
self.router.add_api_route("/ui/certificates", self.ui_certificates_list, methods=["GET"], response_class=HTMLResponse)
self.router.add_api_route("/ui/certificates/{hierarchy}/{name}",
self.ui_certificates_details, methods=["GET"], response_class=HTMLResponse)
# HTTP APIs
self.router.add_api_route("/streams", self.create_stream, methods=["POST"], status_code=201)
self.router.add_api_route("/streams", self.list_streams, methods=["GET"])
self.router.add_api_route("/streams/probe/{stream_id}/{file_path:path}", self.ffprobe_check, methods=["GET"])
self.router.add_api_route("/streams/{stream_id}/{file_path:path}.{ext}", self.handle_upload, methods=["PUT"])
self.router.add_api_route("/streams/{stream_id}/{file_path:path}", self.segment_download, methods=["GET"])
self.router.add_api_route("/streams/{stream_id}/trackName", self.update_track_name, methods=["POST"], status_code=202)
self.router.add_api_route("/certs", self.list_certs, methods=["GET"], status_code=200)
self.router.add_api_route("/certs/{hierarchy}/{name}", self.certificate_details, methods=["GET"], status_code=200)
self.router.add_api_route("/certs/{name}/keypair", self.create_client_keypair, methods=["POST"])
self.router.add_api_route("/certs/{name}/sign", self.sign_client_certificate, methods=["POST"])
# Utilities
def _list_streams(self):
streams: dict[str, Stream] = {}
for stream_path in self.wd.path("streams").iterdir():
if stream_path.is_dir():
stream_file = stream_path / "stream.json"
if stream_file.exists():
with open(stream_file, 'r', encoding='utf-8') as f:
stream_data = json.load(f)
streams[stream_path.name] = Stream.model_validate(stream_data)
return streams
@contextlib.contextmanager
def _open_stream(self, stream_id: int):
"""Context manager helper to save a stream after use.
Note that any exceptions raised within the context will prevent streams from being saved to disk.
"""
stream_id_str = str(stream_id)
stream = self.streams.get(stream_id_str)
if stream is None:
raise HTTPException(status_code=400, detail="Stream ID doesn't exist")
yield stream
self.streams[stream_id_str].save_to_disk(self.wd)
# UI website
def index(self):
return RedirectResponse("/ui/streams")
def ui_streams_list(self, request: Request):
s = self.list_streams()
return self.templates.TemplateResponse(
request=request, name="streams_list.jinja2", context={"streams": s["streams"]}
)
def ui_streams_details(self, request: Request, stream_id: int):
context = {}
stream = self.streams.get(str(stream_id))
if stream is None:
raise HTTPException(status_code=400, detail="Stream ID doesn't exist")
context['stream'] = stream
return self.templates.TemplateResponse(request=request, name="streams_details.jinja2", context=context)
def ui_streams_file_details(self, request: Request, stream_id: int, file_path: str):
context = {}
context['streams'] = self.list_streams()['streams']
context['stream_id'] = stream_id
context['file_path'] = file_path
if file_path.endswith('.crt'):
context['type'] = 'cert'
p = self.wd.path("streams", str(stream_id), file_path)
with open(p, "r") as f:
context['cert'] = json.load(f)
else:
context['type'] = 'media'
context['probe'] = self.ffprobe_check(stream_id, file_path)
context['pretty_probe'] = json.dumps(context['probe'], sort_keys=True, indent=4)
return self.templates.TemplateResponse(request=request, name="streams_file_details.jinja2", context=context)
def ui_certificates_list(self, request: Request):
return self.templates.TemplateResponse(
request=request, name="certificates_list.jinja2", context={"certs": self.list_certs()}
)
def ui_certificates_details(self, request: Request, hierarchy: str, name: str):
context = self.certificate_details(hierarchy, name)
context["certs"] = self.list_certs()
return self.templates.TemplateResponse(request=request, name="certificates_details.jinja2", context=context)
# APIs
def create_stream(self, interface: SupportedIngestInterface = Query(default=SupportedIngestInterface.cmaf)):
# Find the last registered stream
stream_ids = [int(d.name) for d in self.wd.path("streams").iterdir() if d.is_dir() and d.name.isdigit()]
last_stream = max(stream_ids) if stream_ids else 0
stream_id = last_stream + 1
stream_id_str = str(stream_id)
# Initialize entry in stream files map
stream = Stream(
id=stream_id,
strict_mode=self.strict_mode,
interface=interface,
)
self.streams[stream_id_str] = stream
self.wd.mkdir("streams", str(stream_id))
self.streams[stream_id_str].save_to_disk(self.wd)
log.info(f"Stream created: id={stream_id}, interface={interface}")
return stream # TODO Update TH to use sessions instead
def list_streams(self):
return {"streams": list(self.streams.values())}
async def handle_upload(self, stream_id: int, file_path: str, ext: str, req: Request):
"""
Handle file upload for a given stream.
Validate the file name based on the extension and path format.
Always save the uploaded file to disk for further analysis.
If strict mode is enabled, return bad requests with the errors if any.
TODO Currently doesn't gracefully handle overwrite uploads (e.g. when the same session number is reused)
"""
log.debug(f"Upload started: stream={stream_id}, file={file_path}.{ext}")
with self._open_stream(stream_id) as stream:
file_path_with_ext = f"{file_path}.{ext}"
session = stream.last_in_progress_session()
body = await req.body()
# Validate the incoming file upload (path and extension)
errors = []
if ext == "mpd":
# DASH manifest files
if (stream.interface != SupportedIngestInterface.dash):
errors.append("Unsupported manifest object extension")
if session is None:
session = stream.new_session()
root = xml.etree.ElementTree.fromstring(body)
mpd_type = root.attrib.get('type')
if mpd_type == "dynamic" and len(session.uploaded_segments) > 0:
errors.append("Dynamic MPD cannot be uploaded after segments have been uploaded")
if mpd_type == "static" and len(session.uploaded_segments) == 0:
errors.append("Static MPD cannot be uploaded before segments have been uploaded")
if mpd_type == "static":
session.complete = True
path_regex = re.compile(r"^session_\d+/index$")
if not path_regex.match(file_path):
errors.append("DASH manifest must be uploaded as session_X/index.mpd")
session.uploaded_manifests.append((file_path_with_ext, file_path_with_ext + ".crt"))
elif ext == "m3u8":
# HLS manifest files
if stream.interface != SupportedIngestInterface.hls:
errors.append("Unsupported manifest object extension")
if session is None:
session = stream.new_session()
# TODO Lifecycle validation for HLS manifests
session.uploaded_manifests.append((file_path_with_ext, file_path_with_ext + ".crt"))
elif ext == "m4s" or ext == "init":
# Segmented video files
if session is not None:
session.uploaded_segments.append((file_path_with_ext, file_path_with_ext + ".crt"))
else:
errors.append("No active session when uploading " + file_path_with_ext + ", segment uploaded before mpd")
# The Track's init segment is uploaded as `session_name/track_name/track_name.init`.
# Note that the extension is not part of the `file_path` variable.
#
# `/session_1/index.mpd` - Initial upload. Has `MPD@type="dynamic"`.
# `/session_1/video1/video1.init`
# `/session_1/audio1/audio1.init`
# `/session_1/video1/segment_1001.m4s`
# `/session_1/audio1/segment_1001.m4s`
# `/session_1/video1/segment_1002.m4s`
# `/session_1/audio1/segment_1002.m4s`
# `/session_1/video1/segment_1003.m4s`
# `/session_1/audio1/segment_1003.m4s`
# `/session_1/index.mpd` - Final upload. Has `MPD@type="static"`.
path_regex = r"^session_\d+/(?P<trackName>[^/]+)/segment_\d+$"
if ext == "init":
path_regex = r"^session_\d+/(?P<trackName>[^/]+)/[^/]+"
path_regex = re.compile(path_regex)
match = path_regex.match(file_path)
if not match:
errors.append("Path does not adhere to Matter's path format")
else:
# Validate that the trackName is the same as the one assigned during transport allocation.
# https://github.com/CHIP-Specifications/connectedhomeip-spec/blob/master/src/app_clusters/PushAVStreamTransport.adoc#685-trackname-field
track_name_in_path = match.group("trackName")
track_name = stream.track_name
# TODO The track naming strategy has changed in 1.5.1. It's now a per audio/video stream option.
if track_name and track_name != track_name_in_path:
errors.append("Track name mismatch: "
f"{track_name_in_path} != {track_name}, "
"must match TrackName provided in ContainerOptions")
else:
errors.append(f"Invalid extension: {ext}, valid extensions are {', '.join(VALID_EXTENSIONS)}")
# Validation complete, now saving data to disk
file_local_path = self.wd.mkdir("streams", str(stream_id), file_path_with_ext, is_file=True)
# If file already exists, create versioned backup.
# Especially useful for manifests that have a fixed location.
if file_local_path.exists():
# TODO Also needs to update the Session value so that we can retrieve the backed up file,
# which implies we will need to keep both the original upload name (for validation) and the
# actual on-disk file name (for reading).
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_path = file_local_path.with_stem(f"{file_local_path.stem}.{timestamp}")
file_local_path.rename(backup_path)
log.info(f"Backed up existing file to {backup_path}")
cert_details = req.scope["extensions"]["ssl"].get('client_certificate', None)
# TODO If file already exists, come up with a way to version it instead of overwriting it.
if cert_details:
with open(file_local_path.with_suffix(file_local_path.suffix + ".crt"), "w") as f:
f.write(json.dumps(cert_details))
else:
errors.append("File upload did not happen with SSL context")
# Save the file to disk
with open(file_local_path, "wb") as f:
f.write(body)
session_id = session.id if session else None
if len(errors) > 0:
# Add to error uploads (only if not already present)
stream.add_error_upload(session_id, file_path_with_ext, errors)
else:
# Add to valid uploads (only if not already in error uploads)
stream.add_valid_upload(session_id, file_path_with_ext)
if stream.strict_mode and len(errors) > 0:
log.warning(f"Upload validation failed: {errors}")
return JSONResponse(
status_code=400,
content={"errors": errors}
)
log.info("Upload successful:"
f"stream={stream_id}, file={file_path}.{ext}, errors={errors}, strict={stream.strict_mode}")
return Response(status_code=202)
def ffprobe_check(self, stream_id: int, file_path: str):
p = self.wd.path("streams", str(stream_id), file_path)
if not p.exists():
raise HTTPException(404, detail="Media file doesn't exist")
cmd = [
"ffprobe", "-allowed_extensions", "init,m4s",
"-show_streams", "-show_format", "-output_format", "json",
str(p.absolute())
]
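# With -show_streams/-show_format and JSON output, ffprobe emits an object with top-level
# "streams" and "format" keys; that structure is what the UI probe view renders.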
log.debug("Running ffprobe: %s", " ".join(cmd))
proc = subprocess.run(
cmd,
capture_output=True
)
if proc.returncode != 0:
stderr_text = proc.stderr.decode('utf-8', errors='replace')
raise HTTPException(
500,
detail={
"message": "ffprobe failed to analyze the media file",
"stderr": stderr_text,
"command": " ".join(cmd)
}
)
return json.loads(proc.stdout)
async def segment_download(self, stream_id: int, file_path: str):
return FileResponse(self.wd.path("streams", str(stream_id), file_path))
def list_certs(self):
server = [f.name for f in pathlib.Path(self.wd.path("certs", "server")).iterdir()]
device = [f.name for f in pathlib.Path(self.wd.path("certs", "device")).iterdir()]
return {"server": server, "device": device}
def certificate_details(self, hierarchy: str, name: str):
data = pathlib.Path(self.wd.path("certs", hierarchy, name)).read_bytes()
type = "key" if name.endswith(".key") else "cert"
key = None
cert = None
if type == "key":
key = serialization.load_pem_private_key(data, None)
key = {
"key_size": key.key_size,
"private_key": key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
),
"public_key": key.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.PKCS1,
),
}
else:
cert = x509.load_pem_x509_certificate(data)
cert = {
"public_cert": cert.public_bytes(serialization.Encoding.PEM),
"serial_number": hex(cert.serial_number),
"not_valid_before": cert.not_valid_before_utc,
"not_valid_after": cert.not_valid_after_utc,
# public_key? fingerprint?
"issuer": cert.issuer.rfc4514_string(),
"subject": cert.subject.rfc4514_string(),
"extensions": [str(ext) for ext in cert.extensions]
}
return {"type": type, "key": key, "cert": cert}
def create_client_keypair(self, name: str, override: bool = True):
(key, cert, created) = self.device_hierarchy.gen_keypair(name, override)
return {"key": key, "cert": cert, "created": created}
# Seems unused in the current TH tests
# TODO Verify in spec how a track name update should be handled mid-stream
async def update_track_name(self, stream_id: int, track_request: TrackNameRequest):
"""Updates the track_name for a given stream_id."""
with self._open_stream(stream_id) as stream:
stream.track_name = track_request.track_name
def sign_client_certificate(
self, name: str, req: SignClientCertificate, override: bool = True
):
(key, cert, created) = self.device_hierarchy.gen_cert(name, req.csr, override)
return {"key": key, "cert": cert, "created": created}
class PushAvContext:
"""Hold the context for a full Push AV Server including temporary disk, CA hierarchies and web server"""
def __init__(self, host: Optional[str], port: Optional[int], working_directory: Optional[str], dns: Optional[str], server_ip: Optional[str], strict_mode: bool):
self.directory = WorkingDirectory(working_directory)
self.host = host
self.port = port
self.dns = "localhost" if dns is None else f"{dns}._http._tcp.local."
self.strict_mode = strict_mode
# Create CA hierarchies (for webserver and devices)
self.device_hierarchy = CAHierarchy(self.directory.mkdir("certs", "device"), "device", "client")
self.server_hierarchy = CAHierarchy(self.directory.mkdir("certs", "server"), "server", "server")
(self.server_key_file, self.server_cert_file, _) = self.server_hierarchy.gen_keypair(self.dns, ip_address=server_ip)
# mDNS configuration. Registration only happens if the dns isn't localhost.
self.zeroconf = Zeroconf()
self.svc_info = None
if self.dns != "localhost":
self.svc_info = ServiceInfo(
"_http._tcp.local.",
name=self.dns,
addresses=[socket.inet_aton("127.0.0.1")],
port=1234,
)
# Streams holder
self.directory.mkdir("streams")
logger = logging.getLogger("hypercorn.error")
self.app = FastAPI()
self.app.mount("/static", StaticFiles(directory=static_path), name="static")
pas = PushAvServer(self.directory, self.device_hierarchy, strict_mode)
self.app.include_router(pas.router)
@self.app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
logger.error(
f"HTTPExecption: {exc.status_code} {exc.detail}"
)
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail}
)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.cleanup()
async def start(self, shutdown_trigger: Optional[Callable[..., Awaitable]] = None):
"""
Start the Push AV server. Note that this method does not check whether a server is already running.
"""
# Advertise over mDNS
if self.svc_info:
log.info("Advertising the service as %s", self.svc_info)
self.zeroconf.register_service(self.svc_info)
# Start the web server
from hypercorn.asyncio import serve
from hypercorn.config import Config
bind = (self.host or "127.0.0.1") + ":" + (str(self.port or 8000))
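# TLS setup: the server presents the certificate generated by the server hierarchy and
# optionally verifies client certificates against the device hierarchy root
# (verify_mode=ssl.CERT_OPTIONAL); that is what populates the client certificate
# details read in handle_upload.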
config = Config.from_mapping(
bind=bind,
quic_bind=bind,
alpn_protocols=["h2"],
keyfile=self.server_key_file,
certfile=self.server_cert_file,
ca_certs=self.device_hierarchy.root_cert_path,
verify_mode=ssl.CERT_OPTIONAL
)
try:
await serve(self.app, config, shutdown_trigger=shutdown_trigger)
finally:
if self.svc_info:
self.zeroconf.unregister_service(self.svc_info)
def cleanup(self):
self.directory.cleanup()
if __name__ == "__main__":
logging.basicConfig(
format="%(asctime)s|%(name)-8s|%(levelname)-5s|%(message)s",
level=logging.DEBUG,
datefmt="%H:%M:%S",
)
logging.getLogger("hpack").setLevel(logging.WARNING)
parser = argparse.ArgumentParser(
prog="push_av_tool.py",
description="Tooling to help test Matter's Push AV capabilities",
)
parser.add_argument("--host", default="localhost")
parser.add_argument("--port", default=1234)
parser.add_argument(
"--working-directory",
help="Where to store content like certificates or uploaded streams. "
"Default to a temporary directory.",
)
parser.add_argument(
"--dns", help="A mDNS record to adversise, or none if left empty."
)
parser.add_argument("--server-ip", help="The IP address of the server to include in the SSL certificate.")
parser.add_argument("--strict-mode", action='store_true',
help="When enabled, upload must happen on the path described by the Matter specification")
args = parser.parse_args()
with PushAvContext(args.host, args.port, args.working_directory, args.dns, args.server_ip, args.strict_mode) as ctx:
shutdown_event = asyncio.Event()
def _signal_handler():
print("SIGINT received. Shutting down web server.")
shutdown_event.set()
with asyncio.Runner() as runner:
runner.get_loop().add_signal_handler(signal.SIGINT, _signal_handler)
runner.run(ctx.start(shutdown_trigger=shutdown_event.wait))