# blob: 4dd3ef435a0d7859f7cf30cca02b88fd671d08a0 [file] [log] [blame]
import argparse
import asyncio
import datetime
import ipaddress
import json
import logging
import os.path
import pathlib
import random
import re
import signal
import socket
import ssl
import string
import subprocess
import sys
import tempfile
from enum import Enum
from pathlib import Path
from typing import Awaitable, Callable, Literal, Optional
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes, CertificatePublicKeyTypes
from cryptography.x509.oid import NameOID
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from zeroconf import ServiceInfo, Zeroconf
# Directory holding this module (symlinks resolved); the bundled UI assets
# (Jinja templates and static files) are looked up relative to it.
module_dir_path = os.path.split(os.path.realpath(__file__))[0]
templates_path, static_path = (
    os.path.join(module_dir_path, leaf) for leaf in ("templates", "static")
)
# File extensions accepted for uploads when strict mode is enabled.
# The initialisation segments must have .init extension as per CMAF-Ingest requirements.
# https://dashif.org/Ingest/#interface-2-naming
VALID_EXTENSIONS = [
    "mpd",
    "m3u8",
    "m4s",
    "init",
]
class WorkingDirectory:
    """
    Collection of utilities to add convention to the files used by this program.

    The working directory is either a caller-provided directory (created if
    missing and left on disk by `cleanup`) or a temporary directory owned by
    this object (removed by `cleanup`). Instances can be used as context
    managers; leaving the `with` block calls `cleanup`.
    """
    # Handle on the temporary directory, or None when a directory was provided.
    tmp = None

    def __init__(self, directory: Optional[str] = None) -> None:
        """
        Args:
            directory: Directory to use as root. When None, a temporary
                directory prefixed with "TC_PAVS_1_0" is created instead.
        """
        if directory is None:
            self.tmp = tempfile.TemporaryDirectory(prefix="TC_PAVS_1_0")
            # Keep `directory` defined in both modes so it can always be read.
            self.directory = pathlib.Path(self.tmp.name)
        else:
            d = pathlib.Path(directory)
            d.mkdir(parents=True, exist_ok=True)
            self.directory = d

    def __enter__(self):
        return self

    def __exit__(self, exc, value, tb):
        self.cleanup()

    def cleanup(self):
        """Remove the temporary directory, if this object created one."""
        if self.tmp:
            self.tmp.cleanup()

    def root_dir(self) -> Path:
        """Return the root of the working directory."""
        return Path(self.tmp.name) if self.tmp else self.directory

    def path(self, *paths: str) -> Path:
        """Return `paths` joined under the root. Nothing is created on disk."""
        return Path(os.path.join(self.root_dir(), *paths))

    def mkdir(self, *paths: str, is_file=False) -> Path:
        """
        Create a directory using the given path rooted in the working directory.
        If a file is provided, the directory up to that file will be created instead.
        Returns the full path.
        """
        p = self.path(*paths)
        # Make sure the (parent) directories exist.
        target_dir = p.parent if is_file else p
        target_dir.mkdir(parents=True, exist_ok=True)
        return p

    def print_tree(self):
        """Print the root path followed by a visual tree of its content."""
        # TODO Convert this helper to build a HTML representation for use in the UI
        def tree(dir_path: pathlib.Path, prefix: str = ""):
            """A recursive generator, given a directory Path object
            will yield a visual tree structure line by line
            with each line prefixed by the same characters
            """
            # prefix components:
            space = " "
            branch = "│ "
            # pointers:
            tee = "├── "
            last = "└── "
            # iterdir() yields entries in arbitrary order; sort for a stable output.
            contents = sorted(dir_path.iterdir())
            # contents each get pointers that are ├── with a final └── :
            pointers = [tee] * (len(contents) - 1) + [last]
            for pointer, path in zip(pointers, contents):
                is_dir = path.is_dir()
                yield prefix + pointer + path.name + ("/" if is_dir else "")
                if is_dir:  # extend the prefix and recurse:
                    extension = branch if pointer == tee else space
                    # i.e. space because last, └── , above so no more |
                    yield from tree(path, prefix=prefix + extension)

        root = self.root_dir()
        print(root)
        for line in tree(pathlib.Path(root)):
            print(line)
class CAHierarchy:
    """
    Utilities to manage a CA hierarchy on disk.

    A hierarchy is one self-signed root (``root.pem`` / ``root.key`` inside the
    base directory) that directly signs leaf certificates. Leaves are stored
    next to the root as ``<name>.pem`` (bundled with the root certificate) and,
    when the private key is generated here, ``<name>.key``.
    """
    default_ca_duration = datetime.timedelta(days=365.25*20)
    client_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=True,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )
    server_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=False,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    def __init__(self, base: Path, name: str, kind: Literal['server', 'client']) -> None:
        """
        Load the root from `base` if both files exist, otherwise generate it.

        Args:
            base: Directory holding the hierarchy files.
            name: Label used in log messages.
            kind: Selects the key usage / extended key usage of issued leaves.

        Exits the process with status 1 when only one of the root key and
        root certificate is present on disk.
        """
        self.name = name
        self.kind = kind
        self.directory = base
        self.root_cert_path = self.directory / "root.pem"
        self.root_key_path = self.directory / "root.key"
        if self.root_key_path.exists() and self.root_cert_path.exists():
            # Root certificate already exists, re-using them
            self.root_cert = x509.load_pem_x509_certificate(
                self.root_cert_path.read_bytes()
            )
            self.root_key = serialization.load_pem_private_key(
                self.root_key_path.read_bytes(), None
            )
            logging.info("CA Hierarchy loaded from disk: %s", self.name)
        elif self.root_key_path.exists() or self.root_cert_path.exists():
            # Only one of the two file exists, bailing out
            logging.error("root certificate partially exist on disk, stopping early")
            sys.exit(1)
        else:
            self._generate_root()
            self._save_cert("root", self.root_cert, self.root_key, False)
            logging.info("CA Hierarchy generated: %s", self.name)

    def _generate_root(self) -> None:
        """Generate a fresh RSA root key and self-signed root certificate in memory."""
        self.root_key = rsa.generate_private_key(
            public_exponent=65537, key_size=2048
        )
        rand_suffix = "".join(
            random.choices(string.ascii_letters + string.digits, k=16)
        )
        root_cert_subject = x509.Name(
            [
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                x509.NameAttribute(
                    NameOID.COMMON_NAME, "TC_PAVS root " + rand_suffix
                ),
            ]
        )
        # Single timestamp so both validity bounds derive from the same instant.
        now = datetime.datetime.now(datetime.timezone.utc)
        self.root_cert = (
            x509.CertificateBuilder()
            .subject_name(root_cert_subject)
            .issuer_name(root_cert_subject)
            .public_key(self.root_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + self.default_ca_duration)
            .add_extension(
                # We make it so that our root can only issue leaf certificates, no intermediate here.
                x509.BasicConstraints(ca=True, path_length=0), critical=True
            )
            .add_extension(
                x509.KeyUsage(
                    digital_signature=True,
                    content_commitment=False,
                    key_encipherment=False,
                    data_encipherment=False,
                    key_agreement=False,
                    key_cert_sign=True,
                    crl_sign=True,
                    encipher_only=False,
                    decipher_only=False,
                ),
                critical=True,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(
                    self.root_key.public_key()
                ),
                critical=False,
            )
            .sign(self.root_key, hashes.SHA256())
        )

    def _save_cert(
        self,
        name: str,
        cert: x509.Certificate,
        key: Optional[CertificateIssuerPrivateKeyTypes],
        bundle_root: bool,
    ) -> tuple[Optional[Path], Path]:
        """
        Private method that helps with saving a certificate and key to the hierarchy folder.

        This tool isn't meant to be used in production, but instead to help with development
        and as such has the goal to make the CA hierarchy as available as possible, which in
        turn makes it very insecure: private keys are written unencrypted.

        Args:
            name: File stem; files are written as ``<name>.pem`` / ``<name>.key``.
            cert: Certificate to write in PEM form.
            key: Private key to write, or None to write only the certificate.
            bundle_root: When True, the root certificate is appended to the PEM file.

        Returns:
            ``(key_path, cert_path)``; ``key_path`` is None when no key was given.
        """
        cert_path = self.directory / f"{name}.pem"
        key_path = self.directory / f"{name}.key" if key else None
        if key and key_path:
            with open(key_path, "wb") as f:
                f.write(
                    key.private_bytes(
                        encoding=serialization.Encoding.PEM,
                        format=serialization.PrivateFormat.TraditionalOpenSSL,
                        encryption_algorithm=serialization.NoEncryption(),
                    )
                )
        with open(cert_path, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))
            if bundle_root:
                f.write(b"\n")
                f.write(self.root_cert.public_bytes(serialization.Encoding.PEM))
        return (key_path, cert_path)

    def _sign_cert(
        self,
        dns: str,
        public_key: CertificatePublicKeyTypes,
        duration: datetime.timedelta,
        ip_address: Optional[str] = None
    ) -> x509.Certificate:
        """
        Generate and sign a leaf certificate with the root key.

        Args:
            dns: DNS name of the subject.
            public_key: Public key to certify.
            duration: Validity period starting now.
            ip_address: Optional IP address; used as Common Name when given and,
                for server hierarchies, added to the SubjectAlternativeName.
        """
        # Use ip_address for Common Name if provided, otherwise use dns
        common_name = ip_address if ip_address else dns
        subject = x509.Name(
            [
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "TC_PAVS"),
                x509.NameAttribute(NameOID.COMMON_NAME, common_name),
            ]
        )
        is_client = self.kind == "client"
        extended_key_usage = (
            [x509.ExtendedKeyUsageOID.CLIENT_AUTH] if is_client
            else [x509.ExtendedKeyUsageOID.SERVER_AUTH]
        )
        now = datetime.datetime.now(datetime.timezone.utc)
        builder = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(self.root_cert.subject)
            .public_key(public_key)
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + duration)
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None),
                critical=False,
            )
            .add_extension(
                self.client_key_usage_cert if is_client else self.server_key_usage_cert,
                critical=True,
            )
            .add_extension(
                x509.ExtendedKeyUsage(extended_key_usage),
                critical=False,
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(public_key),
                critical=False,
            )
            .add_extension(
                x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                    self.root_cert.extensions.get_extension_for_class(
                        x509.SubjectKeyIdentifier
                    ).value
                ),
                critical=False,
            )
            .add_extension(x509.CRLDistributionPoints([x509.DistributionPoint(
                full_name=[x509.UniformResourceIdentifier("http://not.a.valid.website.com/some/path/to/a.crl")],
                relative_name=None,
                reasons=None,
                crl_issuer=None
            )]), critical=False)
        )
        if self.kind == 'server':
            san_names = [x509.DNSName(dns)]
            if ip_address:
                san_names.append(x509.IPAddress(ipaddress.ip_address(ip_address)))
            # add_extension() returns a new builder rather than mutating in
            # place; the result must be kept or the SAN is silently dropped.
            builder = builder.add_extension(
                x509.SubjectAlternativeName(san_names),
                critical=False,
            )
        return builder.sign(self.root_key, hashes.SHA256())

    def gen_cert(self, dns: str, csr: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1)) -> tuple[Optional[Path], Path, bool]:
        """
        Generate a certificate signed by this CA hierarchy using the provided CSR.

        Args:
            dns: Subject DNS name, also used as the file stem on disk.
            csr: PEM-encoded certificate signing request.
            override: When False, an existing ``<dns>.pem`` + ``<dns>.key`` pair is reused.
            duration: Validity period of a newly signed certificate.

        Returns:
            The path to the key, the path to the cert, and whether it was reused
            or not. The key path is None for a newly signed certificate, since
            the private key stays with the CSR's author.
        """
        signing_request = x509.load_pem_x509_csr(csr.encode('utf-8'))
        # If we don't always override, first check if an existing keypair already exists
        # NOTE(review): reuse requires a .key file, which this method never
        # writes; confirm whether reuse should depend on the certificate only.
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"
            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)
        # Sign certificate
        cert = self._sign_cert(dns, signing_request.public_key(), duration)
        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(
            dns, cert, None, bundle_root=True
        )
        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
        return (key_path, cert_bundle_path, False)

    def gen_keypair(self, dns: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1), ip_address: Optional[str] = None) -> tuple[Path, Path, bool]:
        """
        Generate a private key as well as the associated certificate signed by this CA
        hierarchy. Returns the path to the key, cert, and whether it was reused or not.

        Args:
            dns: Subject DNS name, also used as the file stem on disk.
            override: When False, an existing ``<dns>.pem`` + ``<dns>.key`` pair is reused.
            duration: Validity period of a newly signed certificate.
            ip_address: Optional IP address forwarded to the signing step.
        """
        # If we don't always override, first check if an existing keypair already exists
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"
            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)
        # Generate private key
        key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        # Sign certificate
        cert = self._sign_cert(dns, key.public_key(), duration, ip_address=ip_address)
        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(dns, cert, key, bundle_root=True)
        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
        return (key_path, cert_bundle_path, False)
class SignClientCertificate(BaseModel):
    """Body of a request asking the device CA to sign a client certificate."""

    # Certificate signing request, as text.
    csr: str
class TrackNameRequest(BaseModel):
    """Body of a request that sets the track name recorded for a stream."""

    # New track name to store in the stream details.
    trackName: str
class SupportedIngestInterface(str, Enum):
    """Ingest interfaces a stream can be created with; members compare equal to their string value."""

    cmaf = "cmaf-ingest"
    dash = "dash"
    hls = "hls"
class PushAvServer:
    """
    HTTP routes of the Push AV test server.

    Exposes a small HTML UI (``/ui/...``) plus JSON/HTTP APIs to create
    streams, upload and download segments, probe media with ffprobe and manage
    certificates. Uploaded files live under ``streams/<stream_id>/`` and
    certificates under ``certs/<hierarchy>/`` in the working directory.
    """
    templates = Jinja2Templates(directory=templates_path)

    # Extended path expected for CMAF segments in strict mode:
    # session_<SessionNumber>/<TrackName>/segment_<SegmentNumber>
    _SEGMENT_PATTERN = re.compile(r"^session_\d+/(?P<trackName>[^/]+)/segment_\d+$")

    def __init__(self, wd: WorkingDirectory, device_hierarchy: CAHierarchy, strict_mode: bool):
        """
        Args:
            wd: Working directory holding the ``streams`` and ``certs`` folders.
            device_hierarchy: CA used to issue device (client) certificates.
            strict_mode: Recorded in every created stream; enables upload path validation.
        """
        self.wd = wd
        self.device_hierarchy = device_hierarchy
        self.strict_mode = strict_mode
        self.router = APIRouter()
        # In-memory map to track stream files: {stream_id: {"valid_files": [], "invalid_files": []}}
        self.stream_files_map = {}
        # UI
        self.router.add_api_route("/", self.index, methods=["GET"], response_class=RedirectResponse)
        self.router.add_api_route("/ui/streams", self.ui_streams_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/streams/{stream_id}/{file_path:path}", self.ui_streams_details, methods=["GET"])
        self.router.add_api_route("/ui/certificates", self.ui_certificates_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/certificates/{hierarchy}/{name}",
                                  self.ui_certificates_details, methods=["GET"], response_class=HTMLResponse)
        # HTTP APIs
        self.router.add_api_route("/streams", self.create_stream, methods=["POST"], status_code=201)
        self.router.add_api_route("/streams", self.list_streams, methods=["GET"])
        self.router.add_api_route("/streams/probe/{stream_id}/{file_path:path}", self.ffprobe_check, methods=["GET"])
        self.router.add_api_route("/streams/{stream_id}/{file_path:path}.{ext}", self.handle_upload, methods=["PUT"])
        self.router.add_api_route("/streams/{stream_id}/{file_path:path}", self.segment_download, methods=["GET"])
        self.router.add_api_route("/streams/{stream_id}/trackName", self.update_track_name, methods=["POST"], status_code=202)
        self.router.add_api_route("/certs", self.list_certs, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{hierarchy}/{name}", self.certificate_details, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{name}/keypair", self.create_client_keypair, methods=["POST"])
        self.router.add_api_route("/certs/{name}/sign", self.sign_client_certificate, methods=["POST"])

    # Utilities
    def _read_stream_details(self, stream_id: int):
        """Load ``details.json`` of a stream; raises 404 when the stream is unknown."""
        p = self.wd.path("streams", str(stream_id), "details.json")
        try:
            with open(p, 'r') as file:
                return json.load(file)
        except FileNotFoundError:
            raise HTTPException(404, detail="Stream doesn't exists") from None
        except Exception as e:
            raise HTTPException(500, f"An unexpected error occurred: {e}") from e

    def _stream_file_path(self, stream_id: int, file_path: str) -> Path:
        """
        Return the on-disk path of `file_path` inside the stream's directory.

        `file_path` comes straight from the request URL, so ".." segments or an
        absolute component could otherwise point outside the stream directory.
        Raises a 400 in that case.
        """
        stream_root = self.wd.path("streams", str(stream_id))
        candidate = self.wd.path("streams", str(stream_id), file_path)
        if not candidate.resolve().is_relative_to(stream_root.resolve()):
            raise HTTPException(400, detail="Invalid file path")
        return candidate

    def _validate_upload(self, stream: dict, file_path: str, ext: str) -> Optional[str]:
        """
        Check an upload against the strict-mode rules of `stream`.

        Returns None when the upload is acceptable (or strict mode is off),
        otherwise a human readable reason.
        """
        if not stream.get('strict_mode', False):
            return None
        if ext not in VALID_EXTENSIONS:
            return f"Invalid extension: {ext}, valid extensions are {', '.join(VALID_EXTENSIONS)}"
        if ext in ["mpd", "m3u8"]:
            iface = stream.get('interface', None)
            if (iface == SupportedIngestInterface.dash and ext != "mpd" or
                    iface == SupportedIngestInterface.hls and ext != "m3u8"):
                return "Unsupported manifest object extension"
            return None
        if ext == "m4s":
            # Checks if CMAF extended path matches the pattern session_<SessionNumber>/<TrackName>/segment_<SegmentNumber>
            # https://github.com/CHIP-Specifications/connectedhomeip-spec/blob/master/src/app_clusters/PushAVStreamTransport.adoc#12-operation
            match = self._SEGMENT_PATTERN.match(file_path)
            if not match:
                return "Path does not adhere to Matter's extended path format: session_<SessionNumber>/<TrackName>/segment_<SegmentNumber>"
            # Validate if the trackName is same as the one assigned during transport allocation.
            # https://github.com/CHIP-Specifications/connectedhomeip-spec/blob/master/src/app_clusters/PushAVStreamTransport.adoc#685-trackname-field
            track_name_in_path = match.group("trackName")
            track_name = stream.get('trackName', None)
            if track_name and track_name != track_name_in_path:
                return ("Track name mismatch: "
                        f"{track_name_in_path} != {track_name}, "
                        "must match TrackName provided in ContainerOptions")
        return None

    # UI website
    def index(self):
        """Redirect the site root to the streams page."""
        return RedirectResponse("/ui/streams")

    def ui_streams_list(self, request: Request):
        """Render the list of streams tracked in memory."""
        s = self.list_streams()
        return self.templates.TemplateResponse(
            request=request, name="streams_list.jinja2", context={"streams": s["streams"]}
        )

    def ui_streams_details(self, request: Request, stream_id: int, file_path: str):
        """
        Render the details page of one file of a stream: the recorded client
        certificate (``*.crt``), the stream details (``details.json``) or the
        ffprobe output for anything else.
        """
        context = {}
        context['streams'] = self.list_streams()['streams']
        context['stream_id'] = stream_id
        context['file_path'] = file_path
        if file_path.endswith('.crt'):
            context['type'] = 'cert'
            p = self._stream_file_path(stream_id, file_path)
            with open(p, "r") as f:
                context['cert'] = json.load(f)
        elif file_path == 'details.json':
            context['type'] = 'details'
            context['details'] = self._read_stream_details(stream_id)
        else:
            context['type'] = 'media'
            context['probe'] = self.ffprobe_check(stream_id, file_path)
            context['pretty_probe'] = json.dumps(context['probe'], sort_keys=True, indent=4)
        return self.templates.TemplateResponse(request=request, name="streams_details.jinja2", context=context)

    def ui_certificates_list(self, request: Request):
        """Render the list of certificate files of both hierarchies."""
        return self.templates.TemplateResponse(
            request=request, name="certificates_list.jinja2", context={"certs": self.list_certs()}
        )

    def ui_certificates_details(self, request: Request, hierarchy: str, name: str):
        """Render the details page of one key or certificate file."""
        context = self.certificate_details(hierarchy, name)
        context["certs"] = self.list_certs()
        return self.templates.TemplateResponse(request=request, name="certificates_details.jinja2", context=context)

    # APIs
    def create_stream(self, interface: Optional[SupportedIngestInterface] = None):
        """
        Create the next stream: allocate an id, create its directory, write
        ``details.json`` and start tracking its files in memory.
        """
        # Find the last registered stream. iterdir() order is arbitrary, so
        # take the numeric maximum rather than the last entry yielded.
        streams_root = pathlib.Path(self.wd.path("streams"))
        known_ids = [
            int(d.name) for d in streams_root.iterdir()
            if d.is_dir() and d.name.isdecimal()
        ]
        stream_id = max(known_ids, default=0) + 1
        # TODO Add option to specify Interface-1, Interface-2 DASH, or I2-HLS to improve the strict mode
        p = self.wd.mkdir("streams", str(stream_id))
        stream = {"stream_id": stream_id, "strict_mode": self.strict_mode, "interface": interface}
        with open(p / "details.json", 'w', encoding='utf-8') as f:
            json.dump(stream, f, ensure_ascii=False, indent=4)
        # Initialize entry in stream files map
        self.stream_files_map[str(stream_id)] = {"valid_files": [], "invalid_files": []}
        return stream

    def list_streams(self):
        """Return the streams tracked in memory along with their valid and invalid files."""
        # Return streams directly from the in-memory map
        streams = [
            {
                "id": int(stream_id),
                "valid_files": stream_data["valid_files"],
                "invalid_files": stream_data["invalid_files"],
            }
            for stream_id, stream_data in self.stream_files_map.items()
        ]
        return {"streams": streams}

    async def _handle_upload(self, dst: Path, req: Request):
        """ Handle an upload, sending content to disk at 'dst'.
        Extract the parsed version of a client certificate via a patched TLS
        extension. See https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.getpeercert
        for the exact content.

        The certificate details are written as JSON next to the upload, in a
        file named after 'dst' with an extra ".crt" suffix.
        """
        cert_details = req.scope["extensions"]["ssl"]["client_certificate"]
        with open(dst.with_suffix(dst.suffix + ".crt"), "w") as f:
            f.write(json.dumps(cert_details))
        with open(dst, "wb") as f:
            async for chunk in req.stream():
                f.write(chunk)
        return Response(status_code=202)

    async def handle_upload(self, stream_id: int, file_path: str, ext: str, req: Request):
        """
        Handle any upload if strict-mode isn't enabled.
        Otherwise, check if the segment path format complies with Matter Specification path.

        Uploads failing validation are still stored; they are logged and
        recorded under the stream's "invalid_files" with the reason.
        """
        stream = self._read_stream_details(stream_id)
        extended_path = f"{file_path}.{ext}"
        # Refuse paths that would be written outside of the stream directory.
        self._stream_file_path(stream_id, extended_path)
        validation_error_reason = self._validate_upload(stream, file_path, ext)
        is_valid = validation_error_reason is None
        dst = self.wd.mkdir("streams", str(stream_id), extended_path, is_file=True)
        # Add file to the appropriate list in the stream files map
        logging.debug(f"Upload received: {extended_path}")
        if not is_valid:
            logging.error(f"{extended_path}: {validation_error_reason}")
        tracked = self.stream_files_map.get(str(stream_id))
        if tracked is not None:
            if is_valid:
                if extended_path not in tracked["valid_files"]:
                    tracked["valid_files"].append(extended_path)
            # invalid_files holds dicts, so compare on their "file_path" field
            # to avoid recording the same file again on every re-upload.
            elif all(entry["file_path"] != extended_path for entry in tracked["invalid_files"]):
                tracked["invalid_files"].append({
                    "file_path": extended_path,
                    "validation_error_reason": validation_error_reason
                })
        return await self._handle_upload(dst, req)

    def ffprobe_check(self, stream_id: int, file_path: str):
        """
        Run ffprobe on an uploaded file and return its parsed JSON output.

        Raises a 404 when the file does not exist and a 500 when ffprobe
        exits with a non-zero status.
        """
        p = self._stream_file_path(stream_id, file_path)
        if not p.exists():
            # Raise (not return) so the client actually receives the error status.
            raise HTTPException(404, detail="Stream doesn't exists")
        proc = subprocess.run(
            ["ffprobe", "-show_streams", "-show_format", "-output_format", "json", str(p.absolute())],
            capture_output=True
        )
        if proc.returncode != 0:
            # TODO Add more details (maybe stderr) to the response
            logging.error("ffprobe failed for %s: %s", p, proc.stderr.decode(errors="replace"))
            raise HTTPException(500)
        return json.loads(proc.stdout)

    async def segment_download(self, file_path: str, stream_id: int):
        """Serve a previously uploaded file of a stream; 404 when it is missing."""
        p = self._stream_file_path(stream_id, file_path)
        if not p.is_file():
            raise HTTPException(404, detail="Segment doesn't exist")
        return FileResponse(p)

    def list_certs(self):
        """Return the file names found in the server and device certificate folders."""
        server = sorted(f.name for f in pathlib.Path(self.wd.path("certs", "server")).iterdir())
        device = sorted(f.name for f in pathlib.Path(self.wd.path("certs", "device")).iterdir())
        return {"server": server, "device": device}

    def certificate_details(self, hierarchy: str, name: str):
        """
        Describe one file of a hierarchy: a private key when `name` ends with
        ".key", a certificate otherwise. Raises a 404 when the file is missing.

        Returns:
            ``{"type": "key" | "cert", "key": ..., "cert": ...}`` where the entry
            not matching the type is None.
        """
        try:
            data = pathlib.Path(self.wd.path("certs", hierarchy, name)).read_bytes()
        except FileNotFoundError:
            raise HTTPException(404, detail="Certificate doesn't exist") from None
        kind = "key" if name.endswith(".key") else "cert"
        key = None
        cert = None
        if kind == "key":
            loaded_key = serialization.load_pem_private_key(data, None)
            key = {
                "key_size": loaded_key.key_size,
                "private_key": loaded_key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                ),
                "public_key": loaded_key.public_key().public_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PublicFormat.PKCS1,
                ),
            }
        else:
            loaded_cert = x509.load_pem_x509_certificate(data)
            cert = {
                "public_cert": loaded_cert.public_bytes(serialization.Encoding.PEM),
                "serial_number": hex(loaded_cert.serial_number),
                "not_valid_before": loaded_cert.not_valid_before_utc,
                "not_valid_after": loaded_cert.not_valid_after_utc,
                # public_key? fingerprint?
                "issuer": loaded_cert.issuer.rfc4514_string(),
                "subject": loaded_cert.subject.rfc4514_string(),
                "extensions": [str(ext) for ext in loaded_cert.extensions]
            }
        return {"type": kind, "key": key, "cert": cert}

    def create_client_keypair(self, name: str, override: bool = True):
        """
        Generate a device key pair named `name`.

        Returns ``[key_path, cert_path, reused]``. A list is used instead of a
        set so the order of the elements in the response is deterministic.
        """
        (key, cert, created) = self.device_hierarchy.gen_keypair(name, override)
        return [key, cert, created]

    async def update_track_name(self, stream_id: int, track_request: TrackNameRequest):
        """
        Updates the trackName for a given stream_id.
        """
        stream_details = self._read_stream_details(stream_id)
        stream_details["trackName"] = track_request.trackName
        details_path = self.wd.path("streams", str(stream_id), "details.json")
        try:
            with open(details_path, 'w', encoding='utf-8') as f:
                json.dump(stream_details, f, ensure_ascii=False, indent=4)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to write stream details: {e}") from e

    def sign_client_certificate(
        self, name: str, req: SignClientCertificate, override: bool = True
    ):
        """
        Sign the CSR in `req` with the device CA under the name `name`.

        Returns ``[key_path, cert_path, reused]``; the key path is null for a
        newly signed certificate. A list is used instead of a set so the order
        of the elements in the response is deterministic.
        """
        (key, cert, created) = self.device_hierarchy.gen_cert(name, req.csr, override)
        return [key, cert, created]
class PushAvContext:
    """Hold the context for a full Push AV Server including temporary disk, CA hierarchies and web server"""

    # Fallbacks used when no host / port is provided.
    DEFAULT_HOST = "127.0.0.1"
    DEFAULT_PORT = 8000

    def __init__(self, host: Optional[str], port: Optional[int], working_directory: Optional[str], dns: Optional[str], server_ip: Optional[str], strict_mode: bool):
        """
        Prepare the working directory, both CA hierarchies, the server key pair,
        the optional mDNS record and the FastAPI application.

        Args:
            host: Address to bind; defaults to 127.0.0.1 when falsy.
            port: Port to bind; defaults to 8000 when falsy.
            working_directory: Root directory, or None for a temporary one.
            dns: mDNS instance name to advertise, or None to use "localhost"
                without any advertisement.
            server_ip: Optional IP address to include in the server certificate.
            strict_mode: Forwarded to the PushAvServer routes.
        """
        self.directory = WorkingDirectory(working_directory)
        self.host = host
        self.port = port
        self.dns = "localhost" if dns is None else f"{dns}._http._tcp.local."
        self.strict_mode = strict_mode
        # Create CA hierarchies (for webserver and devices)
        self.device_hierarchy = CAHierarchy(self.directory.mkdir("certs", "device"), "device", "client")
        self.server_hierarchy = CAHierarchy(self.directory.mkdir("certs", "server"), "server", "server")
        (self.server_key_file, self.server_cert_file, _) = self.server_hierarchy.gen_keypair(self.dns, override=True, ip_address=server_ip)
        # mDNS configuration. Registration only happen if the dns isn't localhost.
        self.zeroconf = Zeroconf()
        self.svc_info = None
        if self.dns != "localhost":
            # NOTE(review): the advertised address is always the loopback
            # address, even when server_ip is provided — confirm this is intended.
            self.svc_info = ServiceInfo(
                "_http._tcp.local.",
                name=self.dns,
                addresses=[socket.inet_aton("127.0.0.1")],
                # Advertise the port the web server binds to instead of a fixed value.
                port=int(self.port or self.DEFAULT_PORT),
            )
        # Streams holder
        self.directory.mkdir("streams")
        logger = logging.getLogger("hypercorn.error")
        self.app = FastAPI()
        self.app.mount("/static", StaticFiles(directory=static_path), name="static")
        pas = PushAvServer(self.directory, self.device_hierarchy, strict_mode)
        self.app.include_router(pas.router)

        @self.app.exception_handler(HTTPException)
        async def http_exception_handler(request: Request, exc: HTTPException):
            # Log every HTTPException before turning it into a JSON response.
            logger.error(
                f"HTTPExecption: {exc.status_code} {exc.detail}"
            )
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail}
            )

    async def start(self, shutdown_trigger: Optional[Callable[..., Awaitable]] = None):
        """
        Start the PUSH AV server. Note that this method does not check if a server is already running.

        Serves until hypercorn returns (e.g. once `shutdown_trigger` completes);
        the mDNS record, if any, is unregistered on the way out.
        """
        # Advertise over mDNS
        if self.svc_info:
            logging.info("Advertising the service as %s", self.svc_info)
            self.zeroconf.register_service(self.svc_info)
        # Start the web server
        from hypercorn.asyncio import serve
        from hypercorn.config import Config
        bind = f"{self.host or self.DEFAULT_HOST}:{self.port or self.DEFAULT_PORT}"
        config = Config.from_mapping(
            bind=bind,
            quic_bind=bind,
            alpn_protocols=["h2"],
            keyfile=self.server_key_file,
            certfile=self.server_cert_file,
            ca_certs=self.device_hierarchy.root_cert_path,
            verify_mode=ssl.CERT_OPTIONAL
        )
        try:
            await serve(self.app, config, shutdown_trigger=shutdown_trigger)
        finally:
            if self.svc_info:
                self.zeroconf.unregister_service(self.svc_info)

    def cleanup(self):
        """Release the mDNS resources and clean up the working directory."""
        try:
            # The Zeroconf instance is created unconditionally in __init__.
            self.zeroconf.close()
        finally:
            self.directory.cleanup()
if __name__ == "__main__":
    # Command line entry point: parse the options, build the server context
    # and serve until SIGINT is received.
    logging.basicConfig(
        format="%(asctime)s|%(name)-8s|%(levelname)-5s|%(message)s",
        level=logging.DEBUG,
        datefmt="%H:%M:%S",
    )
    logging.getLogger("hpack").setLevel(logging.WARN)
    parser = argparse.ArgumentParser(
        prog="push_av_tool.py",
        description="Tooling to help test Matter's Push AV capabilities",
    )
    parser.add_argument("--host", default="localhost")
    # type=int so a port given on the command line has the same type as the default.
    parser.add_argument("--port", type=int, default=1234)
    parser.add_argument(
        "--working-directory",
        help="Where to store content like certificates or uploaded streams. "
        "Default to a temporary directory.",
    )
    parser.add_argument(
        "--dns", help="A mDNS record to adversise, or none if left empty."
    )
    parser.add_argument("--server-ip", help="The IP address of the server to include in the SSL certificate.")
    parser.add_argument("--strict-mode", action='store_true',
                        help="When enabled, upload must happen on the path described by the Matter specification")
    args = parser.parse_args()
    ctx = PushAvContext(args.host, args.port, args.working_directory, args.dns, args.server_ip, args.strict_mode)

    async def _serve_until_interrupted() -> None:
        """Run the server until SIGINT is received."""
        # Created inside the running loop, so the event and the signal handler
        # share the loop that serves the application.
        shutdown_event = asyncio.Event()

        def _signal_handler():
            print("SIGINT received. Shutting down web server.")
            shutdown_event.set()

        asyncio.get_running_loop().add_signal_handler(signal.SIGINT, _signal_handler)
        await ctx.start(shutdown_trigger=shutdown_event.wait)

    try:
        asyncio.run(_serve_until_interrupted())
    finally:
        # Always release the working directory, even when serving failed.
        ctx.cleanup()