| import argparse |
| import asyncio |
| import datetime |
| import ipaddress |
| import json |
| import logging |
| import os.path |
| import pathlib |
| import random |
| import re |
| import signal |
| import socket |
| import ssl |
| import string |
| import subprocess |
| import sys |
| import tempfile |
| from enum import Enum |
| from pathlib import Path |
| from typing import Awaitable, Callable, Literal, Optional |
| |
| from cryptography import x509 |
| from cryptography.hazmat.primitives import hashes, serialization |
| from cryptography.hazmat.primitives.asymmetric import rsa |
| from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes, CertificatePublicKeyTypes |
| from cryptography.x509.oid import NameOID |
| from fastapi import APIRouter, FastAPI, HTTPException, Request, Response |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.templating import Jinja2Templates |
| from pydantic import BaseModel |
| from zeroconf import ServiceInfo, Zeroconf |
| |
# Resolve UI resource locations relative to this module so the server can be
# launched from any working directory.
module_dir_path = os.path.dirname(os.path.realpath(__file__))
templates_path = os.path.join(module_dir_path, "templates")
static_path = os.path.join(module_dir_path, "static")
'''
The initialisation segments must have .init extension as per CMAF-Ingest requirements.
https://dashif.org/Ingest/#interface-2-naming
'''
# Upload extensions accepted when strict mode is enabled (manifests, media
# segments and CMAF initialisation segments).
VALID_EXTENSIONS = ["mpd", "m3u8", "m4s", "init"]
| |
| |
class WorkingDirectory:
    """
    Collection of utilities to add convention to the files used by this program.

    Backed either by a caller-provided directory (created if missing) or, when
    no directory is given, by a self-cleaning temporary directory.
    """

    def __init__(self, directory: Optional[str] = None) -> None:
        """
        :param directory: persistent root to use; ``None`` means "use a
            temporary directory that is removed on :meth:`cleanup`".
        """
        # Define both attributes on every instance so root_dir() can dispatch
        # on which backing store is in use (the original relied on a class
        # attribute for `tmp` and never defined `directory` in the tmp branch).
        self.tmp: Optional[tempfile.TemporaryDirectory] = None
        self.directory: Optional[pathlib.Path] = None

        if directory is None:
            self.tmp = tempfile.TemporaryDirectory(prefix="TC_PAVS_1_0")
        else:
            d = pathlib.Path(directory)
            d.mkdir(parents=True, exist_ok=True)
            self.directory = d

    def __enter__(self):
        return self

    def __exit__(self, exc, value, tb):
        self.cleanup()

    def cleanup(self) -> None:
        """Remove the temporary backing directory, if one was created."""
        if self.tmp:
            self.tmp.cleanup()

    def root_dir(self) -> Path:
        """Return the root of the working directory."""
        return Path(self.tmp.name) if self.tmp else self.directory

    def path(self, *paths: str) -> Path:
        """Return the given relative path joined under the working directory."""
        return Path(os.path.join(self.root_dir(), *paths))

    def mkdir(self, *paths: str, is_file=False) -> Path:
        """
        Create a directory using the given path rooted in the working directory.
        If a file is provided, the directory up to that file will be created instead.
        Returns the full path.
        """
        p = self.path(*paths)

        # Ensure the parent directories exist; when the final component is a
        # file, only create up to its parent.
        p2 = pathlib.Path(p)
        if is_file:
            p2 = p2.parent

        p2.mkdir(parents=True, exist_ok=True)

        return p

    def print_tree(self) -> None:
        """Print a visual tree of the working directory to stdout."""
        # TODO Convert this helper to build a HTML representation for use in the UI

        def tree(dir_path: pathlib.Path, prefix: str = ""):
            """A recursive generator, given a directory Path object
            will yield a visual tree structure line by line
            with each line prefixed by the same characters
            """
            # prefix components:
            space = "    "
            branch = "│   "
            # pointers:
            tee = "├── "
            last = "└── "

            contents = list(dir_path.iterdir())
            # contents each get pointers that are ├── with a final └── :
            pointers = [tee] * (len(contents) - 1) + [last]
            for pointer, path in zip(pointers, contents):
                is_dir = path.is_dir()
                yield prefix + pointer + path.name + ("/" if is_dir else "")
                if path.is_dir():  # extend the prefix and recurse:
                    extension = branch if pointer == tee else space
                    # i.e. space because last, └── , above so no more |
                    yield from tree(path, prefix=prefix + extension)

        root = self.root_dir()
        print(root)
        for line in tree(pathlib.Path(root)):
            print(line)
| |
class CAHierarchy:
    """
    Utilities to manage a CA hierarchy on disk.

    The hierarchy lives under ``base``: a self-signed root (``root.pem`` /
    ``root.key``) plus one PEM/key pair per issued leaf. The root is restricted
    to issuing leaf certificates only (BasicConstraints path_length=0).
    """

    # Roughly 20 years — long enough that a development-only root never
    # expires in practice.
    default_ca_duration = datetime.timedelta(days=365.25*20)

    # KeyUsage for leaves intended for TLS client authentication.
    client_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=True,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )
    # KeyUsage for leaves intended for TLS server authentication.
    server_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=False,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    def __init__(self, base: Path, name: str, kind: Literal['server', 'client']) -> None:
        """
        Load the root key/cert from ``base`` when both files exist, otherwise
        generate a fresh self-signed root. Exits the process when only one of
        the two root files is present (partial state we cannot recover from).

        :param base: directory holding the hierarchy's files.
        :param name: human-readable hierarchy name (for logs).
        :param kind: whether issued leaves are for server or client auth.
        """
        self.name = name
        self.kind = kind
        self.directory = base

        self.root_cert_path = self.directory / "root.pem"
        self.root_key_path = self.directory / "root.key"

        if self.root_key_path.exists() and self.root_cert_path.exists():
            # Root certificate already exists, re-using them
            self.root_cert = x509.load_pem_x509_certificate(
                self.root_cert_path.read_bytes()
            )
            self.root_key = serialization.load_pem_private_key(
                self.root_key_path.read_bytes(), None
            )

            logging.info(f"CA Hierarchy loaded from disk: {self.name}")
        elif self.root_key_path.exists() or self.root_cert_path.exists():
            # Only one of the two file exists, bailing out
            logging.error("root certificate partially exist on disk, stopping early")
            sys.exit(1)
        else:
            # Start generating the root certificate
            self.root_key = rsa.generate_private_key(
                public_exponent=65537, key_size=2048
            )
            # Random CN suffix so repeated runs don't produce colliding roots.
            rand_suffix = "".join(
                random.choices(string.ascii_letters + string.digits, k=16)
            )
            root_cert_subject = x509.Name(
                [
                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                    x509.NameAttribute(
                        NameOID.COMMON_NAME, "TC_PAVS root " + rand_suffix
                    ),
                ]
            )
            self.root_cert = (
                x509.CertificateBuilder()
                .subject_name(root_cert_subject)
                .issuer_name(root_cert_subject)
                .public_key(self.root_key.public_key())
                .serial_number(x509.random_serial_number())
                .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
                .not_valid_after(
                    datetime.datetime.now(datetime.timezone.utc) + self.default_ca_duration
                )
                .add_extension(
                    # We make it so that our root can only issue leaf certificates, no intermediate here.
                    x509.BasicConstraints(ca=True, path_length=0), critical=True
                )
                .add_extension(
                    x509.KeyUsage(
                        digital_signature=True,
                        content_commitment=False,
                        key_encipherment=False,
                        data_encipherment=False,
                        key_agreement=False,
                        key_cert_sign=True,
                        crl_sign=True,
                        encipher_only=False,
                        decipher_only=False,
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(
                        self.root_key.public_key()
                    ),
                    critical=False,
                )
                .sign(self.root_key, hashes.SHA256())
            )

            self._save_cert("root", self.root_cert, self.root_key, False)

            logging.info(f"CA Hierarchy generated: {self.name}")

    def _save_cert(
        self,
        name: str,
        cert: x509.Certificate,
        key: Optional[CertificateIssuerPrivateKeyTypes],
        bundle_root: bool,
    ) -> tuple[Optional[Path], Path]:
        """
        Private method that help with saving certificate and key to the hierarchy folder.
        This tool isn't meant to be used in production, but instead to help with development
        and as such have the goal to make the CA hierarchy as available as possible, which in
        turn make it very unsecure.

        :returns: ``(key_path or None, cert_path)``.
        """
        cert_path = self.directory / f"{name}.pem"
        key_path = self.directory / f"{name}.key" if key else None

        if key and key_path:
            # Unencrypted private key on disk: deliberate for a dev-only tool.
            with open(key_path, "wb") as f:
                f.write(
                    key.private_bytes(
                        encoding=serialization.Encoding.PEM,
                        format=serialization.PrivateFormat.TraditionalOpenSSL,
                        encryption_algorithm=serialization.NoEncryption(),
                    )
                )

        with open(cert_path, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))

            if bundle_root:
                # Append the root so the file is a ready-to-serve chain.
                f.write(b"\n")
                f.write(self.root_cert.public_bytes(serialization.Encoding.PEM))

        return (key_path, cert_path)

    def _sign_cert(
        self,
        dns: str,
        public_key: CertificatePublicKeyTypes,
        duration: datetime.timedelta,
        ip_address: Optional[str] = None
    ) -> x509.Certificate:
        """
        Generate and sign a leaf certificate for ``public_key``, valid for
        ``duration`` from now, with key usages matching ``self.kind``.
        """
        # Use ip_address for Common Name if provided, otherwise use dns
        common_name = ip_address if ip_address else dns

        # Sign certificate
        subject = x509.Name(
            [
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "TC_PAVS"),
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            ]
        )

        extended_key_usage = [x509.ExtendedKeyUsageOID.CLIENT_AUTH] if self.kind == "client" else [
            x509.ExtendedKeyUsageOID.SERVER_AUTH]

        builder = (x509.CertificateBuilder()
                   .subject_name(subject)
                   .issuer_name(self.root_cert.subject)
                   .public_key(public_key)
                   .serial_number(x509.random_serial_number())
                   .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
                   .not_valid_after(
                       datetime.datetime.now(datetime.timezone.utc) + duration
                   )
                   .add_extension(
                       x509.BasicConstraints(ca=False, path_length=None),
                       critical=False,
                   )
                   .add_extension(
                       self.client_key_usage_cert if self.kind == "client" else self.server_key_usage_cert,
                       critical=True,
                   )
                   .add_extension(
                       x509.ExtendedKeyUsage(extended_key_usage),
                       critical=False,
                   )
                   .add_extension(
                       x509.SubjectKeyIdentifier.from_public_key(public_key),
                       critical=False,
                   )
                   .add_extension(
                       x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                           self.root_cert.extensions.get_extension_for_class(
                               x509.SubjectKeyIdentifier
                           ).value
                       ),
                       critical=False,
                   )
                   .add_extension(x509.CRLDistributionPoints([x509.DistributionPoint(
                       full_name=[x509.UniformResourceIdentifier("http://not.a.valid.website.com/some/path/to/a.crl")],
                       relative_name=None,
                       reasons=None,
                       crl_issuer=None
                   )]), critical=False)
                   )

        if self.kind == 'server':
            san_names = [x509.DNSName(dns)]
            if ip_address:
                san_names.append(x509.IPAddress(ipaddress.ip_address(ip_address)))
            # CertificateBuilder is immutable: add_extension() returns a NEW
            # builder. The original discarded the return value, so server
            # certificates were issued WITHOUT any SubjectAlternativeName.
            builder = builder.add_extension(
                x509.SubjectAlternativeName(san_names),
                critical=False,
            )

        return builder.sign(self.root_key, hashes.SHA256())

    def gen_cert(self, dns: str, csr: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1)) -> tuple[Path, Path, bool]:
        """
        Generate a certificate signed by this CA hierarchy using the provided CSR.
        Returns the path to the key, cert, and whether it was reused or not.
        Note: no key is generated here (the CSR holder owns it), so the
        returned key path is ``None``.
        """
        signing_request = x509.load_pem_x509_csr(csr.encode('utf-8'))

        # If we don't always override, first check if an existing keypair already exists
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"

            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)

        # Sign certificate
        cert = self._sign_cert(dns, signing_request.public_key(), duration)

        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(
            dns, cert, None, bundle_root=True
        )

        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)

        return (key_path, cert_bundle_path, False)

    def gen_keypair(self, dns: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1), ip_address: Optional[str] = None) -> tuple[Path, Path, bool]:
        """
        Generate a private key as well as the associated certificate signed by this CA
        hierarchy. Returns the path to the key, cert, and whether it was reused or not.
        """

        # If we don't always override, first check if an existing keypair already exists
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"

            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)

        # Generate private key
        key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

        # Sign certificate
        cert = self._sign_cert(dns, key.public_key(), duration, ip_address=ip_address)

        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(dns, cert, key, bundle_root=True)

        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)

        return (key_path, cert_bundle_path, False)
| |
| |
class SignClientCertificate(BaseModel):
    """Request model to sign a client certificate"""
    # PEM-encoded PKCS#10 certificate signing request.
    csr: str
| |
| |
class TrackNameRequest(BaseModel):
    """Request model to update track name for a stream"""
    # Track name assigned during transport allocation; strict-mode uploads are
    # validated against it.
    trackName: str
| |
| |
class SupportedIngestInterface(str, Enum):
    """Ingest interfaces a stream can declare; used by strict-mode validation
    of uploaded manifest extensions."""
    cmaf = "cmaf-ingest"
    dash = "dash"
    hls = "hls"
| |
| |
class PushAvServer:
    """
    HTTP application for Push AV testing: stream upload/inspection endpoints,
    certificate listing and issuance, plus a small HTML UI.
    """

    templates = Jinja2Templates(directory=templates_path)

    def __init__(self, wd: WorkingDirectory, device_hierarchy: CAHierarchy, strict_mode: bool):
        """
        :param wd: working directory holding streams and certificates.
        :param device_hierarchy: CA used to issue device (client) certificates.
        :param strict_mode: when True, uploads must follow the Matter path spec.
        """
        self.wd = wd
        self.device_hierarchy = device_hierarchy
        self.strict_mode = strict_mode
        self.router = APIRouter()

        # In-memory map to track stream files: {stream_id: {"valid_files": [], "invalid_files": []}}
        # NOTE(review): this map is not rebuilt from disk, so streams created
        # in a previous run are invisible to list_streams — confirm intended.
        self.stream_files_map = {}

        # UI
        self.router.add_api_route("/", self.index, methods=["GET"], response_class=RedirectResponse)
        self.router.add_api_route("/ui/streams", self.ui_streams_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/streams/{stream_id}/{file_path:path}", self.ui_streams_details, methods=["GET"])
        self.router.add_api_route("/ui/certificates", self.ui_certificates_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/certificates/{hierarchy}/{name}",
                                  self.ui_certificates_details, methods=["GET"], response_class=HTMLResponse)

        # HTTP APIs
        self.router.add_api_route("/streams", self.create_stream, methods=["POST"], status_code=201)
        self.router.add_api_route("/streams", self.list_streams, methods=["GET"])
        self.router.add_api_route("/streams/probe/{stream_id}/{file_path:path}", self.ffprobe_check, methods=["GET"])

        self.router.add_api_route("/streams/{stream_id}/{file_path:path}.{ext}", self.handle_upload, methods=["PUT"])

        self.router.add_api_route("/streams/{stream_id}/{file_path:path}", self.segment_download, methods=["GET"])
        self.router.add_api_route("/streams/{stream_id}/trackName", self.update_track_name, methods=["POST"], status_code=202)
        self.router.add_api_route("/certs", self.list_certs, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{hierarchy}/{name}", self.certificate_details, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{name}/keypair", self.create_client_keypair, methods=["POST"])
        self.router.add_api_route("/certs/{name}/sign", self.sign_client_certificate, methods=["POST"])

    # Utilities

    def _read_stream_details(self, stream_id: int):
        """Load and return ``details.json`` for a stream, or raise 404/500."""
        p = self.wd.path("streams", str(stream_id), "details.json")

        try:
            with open(p, 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            raise HTTPException(404, detail="Stream doesn't exists")
        except Exception as e:
            raise HTTPException(500, f"An unexpected error occurred: {e}")

    # UI website

    def index(self):
        """Redirect the root URL to the streams UI."""
        return RedirectResponse("/ui/streams")

    def ui_streams_list(self, request: Request):
        """Render the list of known streams."""
        s = self.list_streams()
        return self.templates.TemplateResponse(
            request=request, name="streams_list.jinja2", context={"streams": s["streams"]}
        )

    def ui_streams_details(self, request: Request, stream_id: int, file_path: str):
        """Render details for one uploaded file: certificate, stream details
        or media probe output, depending on the file type."""
        context = {}
        context['streams'] = self.list_streams()['streams']
        context['stream_id'] = stream_id
        context['file_path'] = file_path

        if file_path.endswith('.crt'):
            # Client-certificate JSON stored alongside the upload.
            context['type'] = 'cert'
            p = self.wd.path("streams", str(stream_id), file_path)
            with open(p, "r") as f:
                context['cert'] = json.load(f)
        elif file_path == 'details.json':
            context['type'] = 'details'
            context['details'] = self._read_stream_details(stream_id)
        else:
            context['type'] = 'media'
            context['probe'] = self.ffprobe_check(stream_id, file_path)
            context['pretty_probe'] = json.dumps(context['probe'], sort_keys=True, indent=4)

        return self.templates.TemplateResponse(request=request, name="streams_details.jinja2", context=context)

    def ui_certificates_list(self, request: Request):
        """Render the list of certificates from both hierarchies."""
        return self.templates.TemplateResponse(
            request=request, name="certificates_list.jinja2", context={"certs": self.list_certs()}
        )

    def ui_certificates_details(self, request: Request, hierarchy: str, name: str):
        """Render the parsed content of one certificate or key file."""
        context = self.certificate_details(hierarchy, name)
        context["certs"] = self.list_certs()

        return self.templates.TemplateResponse(request=request, name="certificates_details.jinja2", context=context)

    # APIs

    def create_stream(self, interface: Optional[SupportedIngestInterface] = None):
        """Allocate the next stream id, persist its details.json and register
        it in the in-memory file map."""
        # Find the last registered stream. iterdir() order is unspecified and
        # names compare lexicographically ("10" < "9"), so take the numeric
        # maximum instead of trusting the last directory entry.
        dirs = [d for d in pathlib.Path(self.wd.path("streams")).iterdir() if d.is_dir()]
        last_stream = max((int(d.name) for d in dirs), default=0)
        stream_id = last_stream + 1

        # TODO Add option to specify Interface-1, Interface-2 DASH, or I2-HLS to improve the strict mode
        p = self.wd.mkdir("streams", str(stream_id))
        stream = {"stream_id": stream_id, "strict_mode": self.strict_mode, "interface": interface}

        with open(p / "details.json", 'w', encoding='utf-8') as f:
            json.dump(stream, f, ensure_ascii=False, indent=4)

        # Initialize entry in stream files map
        self.stream_files_map[str(stream_id)] = {"valid_files": [], "invalid_files": []}

        return stream

    def list_streams(self):
        """Return all streams known to this process with their file lists."""
        # Return streams directly from the in-memory map
        streams = []

        for stream_id, stream_data in self.stream_files_map.items():
            streams.append({
                "id": int(stream_id),
                "valid_files": stream_data["valid_files"],
                "invalid_files": stream_data["invalid_files"]
            })

        return {"streams": streams}

    async def _handle_upload(self, dst: Path, req: Request):
        """ Handle an upload, sending content to disk at 'dst'.

        Extract the parsed version of a client certificate via a patched TLS
        extension. See https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.getpeercert
        for the exact content.
        """

        cert_details = req.scope["extensions"]["ssl"]["client_certificate"]

        # Persist the peer certificate next to the payload for later UI review.
        with open(dst.with_suffix(dst.suffix + ".crt"), "w") as f:
            f.write(json.dumps(cert_details))

        with open(dst, "wb") as f:
            async for chunk in req.stream():
                f.write(chunk)

        return Response(status_code=202)

    async def handle_upload(self, stream_id: int, file_path: str, ext: str, req: Request):
        """
        Handle any upload if strict-mode isn't enabled.
        Otherwise, check if the segment path format complies with Matter Specification path.
        """
        stream = self._read_stream_details(stream_id)
        is_valid = True
        validation_error_reason = ""

        if stream.get('strict_mode', False):
            if ext not in VALID_EXTENSIONS:
                is_valid = False
                validation_error_reason = f"Invalid extension: {ext}, valid extensions are {', '.join(VALID_EXTENSIONS)}"
            elif ext in ["mpd", "m3u8"]:
                # Manifest extension must match the declared ingest interface.
                iface = stream.get('interface', None)
                if (iface == SupportedIngestInterface.dash and ext != "mpd" or
                        iface == SupportedIngestInterface.hls and ext != "m3u8"):
                    is_valid = False
                    validation_error_reason = "Unsupported manifest object extension"
            elif ext == "m4s":
                # Checks if CMAF extended path matches the pattern session_<SessionNumber>/<TrackName>/segment_<SegmentNumber>
                # https://github.com/CHIP-Specifications/connectedhomeip-spec/blob/master/src/app_clusters/PushAVStreamTransport.adoc#12-operation
                segment_pattern = re.compile(r"^session_\d+/(?P<trackName>[^/]+)/segment_\d+$")
                match = segment_pattern.match(file_path)
                if not match:
                    is_valid = False
                    validation_error_reason = "Path does not adhere to Matter's extended path format: session_<SessionNumber>/<TrackName>/segment_<SegmentNumber>"
                else:
                    # Validate if the trackName is same as the one assigned during transport allocation.
                    # https://github.com/CHIP-Specifications/connectedhomeip-spec/blob/master/src/app_clusters/PushAVStreamTransport.adoc#685-trackname-field
                    track_name_in_path = match.group("trackName")
                    track_name = stream.get('trackName', None)
                    if track_name and track_name != track_name_in_path:
                        is_valid = False
                        validation_error_reason = ("Track name mismatch: "
                                                   f"{track_name_in_path} != {track_name}, "
                                                   "must match TrackName provided in ContainerOptions")

        dst = self.wd.mkdir("streams", str(stream_id), f"{file_path}.{ext}", is_file=True)
        extended_path = f"{file_path}.{ext}"

        # Add file to the appropriate list in the stream files map
        logging.debug(f"Upload received: {extended_path}")
        stream_id_str = str(stream_id)
        if stream_id_str in self.stream_files_map:
            if is_valid and extended_path not in self.stream_files_map[stream_id_str]["valid_files"]:
                self.stream_files_map[stream_id_str]["valid_files"].append(extended_path)
            if not is_valid:
                logging.error(f"{extended_path}: {validation_error_reason}")
                invalid_files = self.stream_files_map[stream_id_str]["invalid_files"]
                # Entries are dicts, so deduplicate on the recorded file_path.
                # (The original compared the bare string against the dicts,
                # which never matched and duplicated entries on every retry.)
                if not any(entry["file_path"] == extended_path for entry in invalid_files):
                    invalid_files.append({
                        "file_path": extended_path,
                        "validation_error_reason": validation_error_reason
                    })

        return await self._handle_upload(dst, req)

    def ffprobe_check(self, stream_id: int, file_path: str):
        """Run ffprobe on an uploaded file and return its parsed JSON output.

        Raises 404 when the file doesn't exist and 500 when ffprobe fails.
        """
        p = self.wd.path("streams", str(stream_id), file_path)

        if not p.exists():
            # Must be raised, not returned: the original returned the
            # HTTPException object, which FastAPI would try to serialize.
            raise HTTPException(404, detail="Stream doesn't exists")

        proc = subprocess.run(
            ["ffprobe", "-show_streams", "-show_format", "-output_format", "json", str(p.absolute())],
            capture_output=True
        )

        if proc.returncode != 0:
            # Surface ffprobe's stderr so callers can see why probing failed.
            raise HTTPException(500, detail=proc.stderr.decode(errors="replace"))

        return json.loads(proc.stdout)

    async def segment_download(self, file_path: str, stream_id: int):
        """Serve an uploaded file back, or 404 when it doesn't exist."""
        p = self.wd.path("streams", str(stream_id), file_path)
        if not p.exists():
            raise HTTPException(404, detail="Stream doesn't exists")
        return FileResponse(p)

    def list_certs(self):
        """List certificate/key file names for both hierarchies."""
        server = [f.name for f in pathlib.Path(self.wd.path("certs", "server")).iterdir()]
        device = [f.name for f in pathlib.Path(self.wd.path("certs", "device")).iterdir()]

        return {"server": server, "device": device}

    def certificate_details(self, hierarchy: str, name: str):
        """Parse a PEM file from the given hierarchy and return its details.

        Files ending in ``.key`` are treated as private keys, everything else
        as certificates.
        """
        data = pathlib.Path(self.wd.path("certs", hierarchy, name)).read_bytes()
        # Avoid shadowing the `type` builtin with the local variable.
        file_type = "key" if name.endswith(".key") else "cert"

        key = None
        cert = None
        if file_type == "key":
            loaded_key = serialization.load_pem_private_key(data, None)
            key = {
                "key_size": loaded_key.key_size,
                "private_key": loaded_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                ),
                "public_key": loaded_key.public_key().public_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PublicFormat.PKCS1,
                ),
            }
        else:
            loaded_cert = x509.load_pem_x509_certificate(data)
            cert = {
                "public_cert": loaded_cert.public_bytes(serialization.Encoding.PEM),
                "serial_number": hex(loaded_cert.serial_number),
                "not_valid_before": loaded_cert.not_valid_before_utc,
                "not_valid_after": loaded_cert.not_valid_after_utc,
                # public_key? fingerprint?
                "issuer": loaded_cert.issuer.rfc4514_string(),
                "subject": loaded_cert.subject.rfc4514_string(),
                "extensions": [str(ext) for ext in loaded_cert.extensions]
            }

        return {"type": file_type, "key": key, "cert": cert}

    def create_client_keypair(self, name: str, override: bool = True):
        """Generate (or reuse) a client key/cert pair named after ``name``."""
        (key, cert, reused) = self.device_hierarchy.gen_keypair(name, override)

        # The original returned a set literal {key, cert, reused}: unordered
        # and unlabeled, which made the JSON response non-deterministic.
        return {"key": key, "cert": cert, "reused": reused}

    async def update_track_name(self, stream_id: int, track_request: TrackNameRequest):
        """
        Updates the trackName for a given stream_id.
        """
        stream_details = self._read_stream_details(stream_id)

        stream_details["trackName"] = track_request.trackName

        details_path = self.wd.path("streams", str(stream_id), "details.json")
        try:
            with open(details_path, 'w', encoding='utf-8') as f:
                json.dump(stream_details, f, ensure_ascii=False, indent=4)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to write stream details: {e}")

    def sign_client_certificate(
        self, name: str, req: SignClientCertificate, override: bool = True
    ):
        """Sign the provided CSR with the device hierarchy. The key path is
        ``None`` since the CSR holder keeps the private key."""
        (key, cert, reused) = self.device_hierarchy.gen_cert(name, req.csr, override)

        # Labeled mapping instead of the original's set literal (see
        # create_client_keypair).
        return {"key": key, "cert": cert, "reused": reused}
| |
| |
class PushAvContext:
    """Hold the context for a full Push AV Server including temporary disk, CA hierarchies and web server"""

    def __init__(self, host: Optional[str], port: Optional[int], working_directory: Optional[str], dns: Optional[str], server_ip: Optional[str], strict_mode: bool):
        """
        :param host: bind address (defaults to 127.0.0.1 at start()).
        :param port: bind port (defaults to 8000 at start()).
        :param working_directory: persistent storage root, or None for a temp dir.
        :param dns: mDNS service name to advertise, or None to skip mDNS.
        :param server_ip: extra IP to embed in the server certificate SAN.
        :param strict_mode: forwarded to PushAvServer upload validation.
        """
        self.directory = WorkingDirectory(working_directory)
        self.host = host
        self.port = port
        self.dns = "localhost" if dns is None else f"{dns}._http._tcp.local."
        self.strict_mode = strict_mode

        # Create CA hierarchies (for webserver and devices)
        self.device_hierarchy = CAHierarchy(self.directory.mkdir("certs", "device"), "device", "client")
        self.server_hierarchy = CAHierarchy(self.directory.mkdir("certs", "server"), "server", "server")
        (self.server_key_file, self.server_cert_file, _) = self.server_hierarchy.gen_keypair(self.dns, override=True, ip_address=server_ip)

        # mDNS configuration. Registration only happen if the dns isn't localhost.
        self.zeroconf = Zeroconf()
        self.svc_info = None

        if self.dns != "localhost":
            self.svc_info = ServiceInfo(
                "_http._tcp.local.",
                name=self.dns,
                # NOTE(review): advertises the loopback address, so remote
                # devices resolving this record cannot reach the server —
                # confirm whether the real interface address should be used.
                addresses=[socket.inet_aton("127.0.0.1")],
                # Advertise the actual listening port rather than a hard-coded
                # 1234 (argparse delivers --port as a string, hence int()).
                port=int(self.port) if self.port is not None else 1234,
            )

        # Streams holder
        self.directory.mkdir("streams")

        logger = logging.getLogger("hypercorn.error")
        self.app = FastAPI()
        self.app.mount("/static", StaticFiles(directory=static_path), name="static")
        pas = PushAvServer(self.directory, self.device_hierarchy, strict_mode)
        self.app.include_router(pas.router)

        @self.app.exception_handler(HTTPException)
        async def http_exception_handler(request: Request, exc: HTTPException):
            logger.error(
                f"HTTPException: {exc.status_code} {exc.detail}"
            )
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail}
            )

    async def start(self, shutdown_trigger: Optional[Callable[..., Awaitable]] = None):
        """
        Start the PUSH AV server. Note that method do not check if a server is already running.
        """
        # Advertise over mDNS
        if self.svc_info:
            logging.info("Advertising the service as %s", self.svc_info)
            self.zeroconf.register_service(self.svc_info)

        # Start the web server. Hypercorn is imported lazily so the module can
        # be imported without it installed.
        from hypercorn.asyncio import serve
        from hypercorn.config import Config
        # NOTE(review): fallback port here is 8000 while the CLI default is
        # 1234 — confirm which default is intended.
        bind = (self.host or "127.0.0.1") + ":" + (str(self.port or 8000))
        config = Config.from_mapping(
            bind=bind,
            quic_bind=bind,
            alpn_protocols=["h2"],
            keyfile=self.server_key_file,
            certfile=self.server_cert_file,
            # Client certificates are optional so the UI stays reachable from
            # a plain browser while devices can still present theirs.
            ca_certs=self.device_hierarchy.root_cert_path,
            verify_mode=ssl.CERT_OPTIONAL
        )

        try:
            await serve(self.app, config, shutdown_trigger=shutdown_trigger)

        finally:
            # Always withdraw the mDNS record, even if the server crashed.
            if self.svc_info:
                self.zeroconf.unregister_service(self.svc_info)

    def cleanup(self):
        """Remove the working directory (no-op for persistent directories)."""
        self.directory.cleanup()
| |
| |
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s|%(name)-8s|%(levelname)-5s|%(message)s",
        level=logging.DEBUG,
        datefmt="%H:%M:%S",
    )
    # hpack (HTTP/2 header compression) is extremely chatty at DEBUG.
    logging.getLogger("hpack").setLevel(logging.WARN)

    parser = argparse.ArgumentParser(
        prog="push_av_tool.py",
        description="Tooling to help test Matter's Push AV capabilities",
    )

    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", default=1234)
    parser.add_argument(
        "--working-directory",
        help="Where to store content like certificates or uploaded streams. "
        "Default to a temporary directory.",
    )
    parser.add_argument(
        "--dns", help="A mDNS record to advertise, or none if left empty."
    )
    parser.add_argument("--server-ip", help="The IP address of the server to include in the SSL certificate.")
    parser.add_argument("--strict-mode", action='store_true',
                        help="When enabled, upload must happen on the path described by the Matter specification")

    args = parser.parse_args()

    ctx = PushAvContext(args.host, args.port, args.working_directory, args.dns, args.server_ip, args.strict_mode)

    async def _amain() -> None:
        """Run the server until SIGINT sets the shutdown event."""
        shutdown_event = asyncio.Event()

        def _signal_handler():
            print("SIGINT received. Shutting down web server.")
            shutdown_event.set()

        # Registering the handler on the running loop replaces the deprecated
        # get_event_loop()/run_until_complete pattern.
        asyncio.get_running_loop().add_signal_handler(signal.SIGINT, _signal_handler)
        await ctx.start(shutdown_trigger=shutdown_event.wait)

    asyncio.run(_amain())

    ctx.cleanup()