blob: a977f74b003df8df20ad2b558219b84a51b92589 [file] [log] [blame]
import argparse
import datetime
import json
import logging
import multiprocessing
import os.path
import pathlib
import random
import socket
import ssl
import string
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Literal, Optional, Union
import uvicorn
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.types import CertificateIssuerPrivateKeyTypes, CertificatePublicKeyTypes
from cryptography.x509.oid import NameOID
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
# Monkey patch uvicorn to make the underlying transport available to us.
# That will let us access the ssl context and get the client certificate information.
from uvicorn.protocols.http.h11_impl import H11Protocol
from zeroconf import ServiceInfo, Zeroconf
http_tools_protocol_old__should_upgrade = H11Protocol._should_upgrade


def http_tools_protocol_new__should_upgrade(self):
    """Wrap uvicorn's upgrade check to stash the raw transport in the ASGI scope.

    Exposing the transport lets request handlers reach the underlying SSL
    object and read the peer (client) certificate on mutual-TLS requests.
    """
    # Keep and propagate the original result: `_should_upgrade` returns a
    # boolean that uvicorn uses to decide on protocol upgrades. The previous
    # wrapper dropped it, making the patched method always return None (falsy).
    result = http_tools_protocol_old__should_upgrade(self)
    self.scope["transport"] = self.transport
    return result


H11Protocol._should_upgrade = http_tools_protocol_new__should_upgrade
# End monkey patch
class WorkingDirectory:
    """
    Collection of utilities to add convention to the files used by this program.
    """

    # Handle to the TemporaryDirectory when no explicit directory was supplied;
    # stays None when the caller provided a persistent directory.
    tmp = None

    def __init__(self, directory: Optional[str] = None) -> None:
        """Root the working area at *directory*, or in a fresh temp dir."""
        if directory is not None:
            root = pathlib.Path(directory)
            root.mkdir(parents=True, exist_ok=True)
            self.directory = root
        else:
            self.tmp = tempfile.TemporaryDirectory(prefix="TC_PAVS_1_0")

    def __enter__(self):
        return self

    def __exit__(self, exc, value, tb):
        self.cleanup()

    def cleanup(self):
        """Dispose of the temporary directory, when one was created."""
        if self.tmp is not None:
            self.tmp.cleanup()

    def root_dir(self) -> Path:
        """Return the directory under which every managed file lives."""
        if self.tmp:
            return Path(self.tmp.name)
        return self.directory

    def path(self, *paths: str) -> Path:
        """Join *paths* onto the working-directory root."""
        return Path(os.path.join(self.root_dir(), *paths))

    def mkdir(self, *paths: str, is_file=False) -> Path:
        """
        Create a directory using the given path rooted in the working directory.
        If a file is provided, the directory up to that file will be created instead.
        Returns the full path.
        """
        full_path = self.path(*paths)
        target = pathlib.Path(full_path)
        if is_file:
            # For a file path, only its containing directory is created.
            target = target.parent
        target.mkdir(parents=True, exist_ok=True)
        return full_path

    def print_tree(self):
        # TODO Convert this helper to build a HTML representation for use in the UI
        def tree(dir_path: pathlib.Path, prefix: str = ""):
            """A recursive generator, given a directory Path object
            will yield a visual tree structure line by line
            with each line prefixed by the same characters
            """
            # prefix components (NOTE(review): widths assumed to be the usual
            # 4-char tree-drawing strings — the pasted source lost whitespace):
            space = "    "
            branch = "│   "
            # pointers:
            tee = "├── "
            last = "└── "
            entries = list(dir_path.iterdir())
            # Every entry gets a ├── pointer except the final one, which gets └──.
            pointers = [tee] * (len(entries) - 1) + [last]
            for pointer, entry in zip(pointers, entries):
                suffix = "/" if entry.is_dir() else ""
                yield prefix + pointer + entry.name + suffix
                if entry.is_dir():
                    # Recurse with an extended prefix; after └── no bar is needed.
                    deeper = branch if pointer == tee else space
                    yield from tree(entry, prefix=prefix + deeper)

        root = self.root_dir()
        print(root)
        for line in tree(pathlib.Path(root)):
            print(line)
class CAHierarchy:
    """
    Utilities to manage a CA hierarchy on disk.

    A single self-signed root is generated on first use (or re-loaded when it
    already exists under the hierarchy directory) and directly signs leaf
    certificates. ``kind`` selects whether leaves are issued for TLS clients
    or TLS servers.
    """

    # Long-lived root so a persisted working directory can be reused for years.
    default_ca_duration = datetime.timedelta(days=365.25*20)

    # Key usage applied to leaf certificates issued for TLS clients.
    client_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=True,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    # Key usage applied to leaf certificates issued for TLS servers.
    server_key_usage_cert = x509.KeyUsage(
        digital_signature=True,
        content_commitment=False,
        key_encipherment=False,
        data_encipherment=False,
        key_agreement=False,
        key_cert_sign=False,
        crl_sign=False,
        encipher_only=False,
        decipher_only=False,
    )

    def __init__(self, base: Path, name: str, kind: Literal['server', 'client']) -> None:
        """Load the root key/cert from *base*, or generate and persist a new root.

        Exits the process when only one of root.pem/root.key is present, since
        a half-written root cannot be trusted or safely regenerated.
        """
        self.name = name
        self.kind = kind
        self.directory = base
        self.root_cert_path = self.directory / "root.pem"
        self.root_key_path = self.directory / "root.key"
        if self.root_key_path.exists() and self.root_cert_path.exists():
            # Root certificate already exists, re-using it.
            self.root_cert = x509.load_pem_x509_certificate(
                self.root_cert_path.read_bytes()
            )
            self.root_key = serialization.load_pem_private_key(
                self.root_key_path.read_bytes(), None
            )
            logging.info(f"CA Hierarchy loaded from disk: {self.name}")
        elif self.root_key_path.exists() or self.root_cert_path.exists():
            # Only one of the two files exists, bailing out.
            logging.error("root certificate partially exist on disk, stopping early")
            sys.exit(1)
        else:
            # Start generating the root certificate.
            self.root_key = rsa.generate_private_key(
                public_exponent=65537, key_size=2048
            )
            # Random suffix keeps independently generated roots distinguishable.
            rand_suffix = "".join(
                random.choices(string.ascii_letters + string.digits, k=16)
            )
            root_cert_subject = x509.Name(
                [
                    x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                    x509.NameAttribute(
                        NameOID.COMMON_NAME, "TC_PAVS root " + rand_suffix
                    ),
                ]
            )
            self.root_cert = (
                x509.CertificateBuilder()
                .subject_name(root_cert_subject)
                .issuer_name(root_cert_subject)
                .public_key(self.root_key.public_key())
                .serial_number(x509.random_serial_number())
                .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
                .not_valid_after(
                    datetime.datetime.now(datetime.timezone.utc) + self.default_ca_duration
                )
                .add_extension(
                    # We make it so that our root can only issue leaf certificates, no intermediate here.
                    x509.BasicConstraints(ca=True, path_length=0), critical=True
                )
                .add_extension(
                    x509.KeyUsage(
                        digital_signature=True,
                        content_commitment=False,
                        key_encipherment=False,
                        data_encipherment=False,
                        key_agreement=False,
                        key_cert_sign=True,
                        crl_sign=True,
                        encipher_only=False,
                        decipher_only=False,
                    ),
                    critical=True,
                )
                .add_extension(
                    x509.SubjectKeyIdentifier.from_public_key(
                        self.root_key.public_key()
                    ),
                    critical=False,
                )
                .sign(self.root_key, hashes.SHA256())
            )
            self._save_cert("root", self.root_cert, self.root_key, False)
            logging.info(f"CA Hierarchy generated: {self.name}")

    def _save_cert(
        self,
        name: str,
        cert: x509.Certificate,
        key: Union[CertificateIssuerPrivateKeyTypes, None],
        bundle_root: bool,
    ) -> tuple[Optional[Path], Path]:
        """
        Private method that helps with saving certificate and key to the hierarchy folder.

        Returns ``(key_path, cert_path)``; ``key_path`` is ``None`` when no key
        was provided. With ``bundle_root`` the root certificate is appended to
        the certificate file so it can be served as a full chain.

        This tool isn't meant to be used in production, but instead to help with
        development and as such has the goal to make the CA hierarchy as available
        as possible, which in turn makes it very unsecure (keys are written to
        disk unencrypted).
        """
        cert_path = self.directory / f"{name}.pem"
        key_path = self.directory / f"{name}.key" if key else None
        if key:
            with open(key_path, "wb") as f:
                f.write(
                    key.private_bytes(
                        encoding=serialization.Encoding.PEM,
                        format=serialization.PrivateFormat.TraditionalOpenSSL,
                        encryption_algorithm=serialization.NoEncryption(),
                    )
                )
        with open(cert_path, "wb") as f:
            f.write(cert.public_bytes(serialization.Encoding.PEM))
            if bundle_root:
                f.write(b"\n")
                f.write(self.root_cert.public_bytes(serialization.Encoding.PEM))
        return (key_path, cert_path)

    def _sign_cert(
        self,
        dns: str,
        public_key: CertificatePublicKeyTypes,
        duration: datetime.timedelta
    ) -> x509.Certificate:
        """
        Generate and sign a leaf certificate for *dns*, valid for *duration*.
        """
        # Sign certificate
        subject = x509.Name(
            [
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "CSA"),
                x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "TC_PAVS"),
                x509.NameAttribute(NameOID.COMMON_NAME, dns),
            ]
        )
        extended_key_usage = [x509.ExtendedKeyUsageOID.CLIENT_AUTH] if self.kind == "client" else [
            x509.ExtendedKeyUsageOID.SERVER_AUTH]
        builder = (x509.CertificateBuilder()
                   .subject_name(subject)
                   .issuer_name(self.root_cert.subject)
                   .public_key(public_key)
                   .serial_number(x509.random_serial_number())
                   .not_valid_before(datetime.datetime.now(datetime.timezone.utc))
                   .not_valid_after(
                       datetime.datetime.now(datetime.timezone.utc) + duration
                   )
                   .add_extension(
                       x509.BasicConstraints(ca=False, path_length=None),
                       critical=False,
                   )
                   .add_extension(
                       self.client_key_usage_cert if self.kind == "client" else self.server_key_usage_cert,
                       critical=True,
                   )
                   .add_extension(
                       x509.ExtendedKeyUsage(extended_key_usage),
                       critical=False,
                   )
                   .add_extension(
                       x509.SubjectKeyIdentifier.from_public_key(public_key),
                       critical=False,
                   )
                   .add_extension(
                       x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                           self.root_cert.extensions.get_extension_for_class(
                               x509.SubjectKeyIdentifier
                           ).value
                       ),
                       critical=False,
                   )
                   .add_extension(x509.CRLDistributionPoints([x509.DistributionPoint(
                       full_name=[x509.UniformResourceIdentifier("http://not.a.valid.website.com/some/path/to/a.crl")],
                       relative_name=None,
                       reasons=None,
                       crl_issuer=None
                   )]), critical=False)
                   )
        if self.kind == 'server':
            # CertificateBuilder is immutable: add_extension returns a NEW
            # builder, so the result must be re-assigned. Previously the return
            # value was discarded and server certificates were issued without
            # their SubjectAlternativeName.
            builder = builder.add_extension(
                x509.SubjectAlternativeName([x509.DNSName(dns)]),
                critical=False,
            )
        return builder.sign(self.root_key, hashes.SHA256())

    def gen_cert(self, dns: str, csr: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1)) -> tuple[Optional[Path], Path, bool]:
        """
        Generate a certificate signed by this CA hierarchy using the provided CSR.

        Returns ``(key_path, cert_path, reused)``. Since the private key stays
        with the CSR's author, ``key_path`` is ``None`` unless a previously
        generated on-disk keypair for *dns* is being reused.
        """
        signing_request = x509.load_pem_x509_csr(csr.encode('utf-8'))
        # Early validation that the CSR carries a usable public key; raises
        # before we consult (and possibly return) any cached material.
        signing_request.public_key()
        # If we don't always override, first check if an existing keypair already exists
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"
            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)
        # Sign certificate
        cert = self._sign_cert(dns, signing_request.public_key(), duration)
        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(
            dns, cert, None, bundle_root=True
        )
        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
        return (key_path, cert_bundle_path, False)

    def gen_keypair(self, dns: str, override=False, duration: datetime.timedelta = datetime.timedelta(hours=1)) -> tuple[Path, Path, bool]:
        """
        Generate a private key as well as the associated certificate signed by this CA
        hierarchy. Returns the path to the key, cert, and whether it was reused or not.
        """
        # If we don't always override, first check if an existing keypair already exists
        if not override:
            cert_path = self.directory / f"{dns}.pem"
            key_path = self.directory / f"{dns}.key"
            if cert_path.exists() and key_path.exists():
                return (key_path, cert_path, True)
        # Generate private key
        key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        # Sign certificate
        cert = self._sign_cert(dns, key.public_key(), duration)
        # Save that information to disk
        (key_path, cert_bundle_path) = self._save_cert(dns, cert, key, bundle_root=True)
        logging.debug("leaf generated. dns=%s; path=%s", dns, cert_bundle_path)
        return (key_path, cert_bundle_path, False)
class SignClientCertificate(BaseModel):
    """Request model to sign a client certificate"""

    # PEM-encoded certificate signing request (PKCS#10) supplied by the device.
    csr: str
class PushAvServer:
    """HTTP endpoints (JSON APIs plus a small HTML UI) for receiving and
    inspecting pushed AV streams and for managing test certificates."""

    templates = Jinja2Templates(directory="templates")

    def __init__(self, wd: WorkingDirectory, device_hierarchy: CAHierarchy):
        self.wd = wd
        self.device_hierarchy = device_hierarchy
        self.router = APIRouter()
        # UI
        self.router.add_api_route("/", self.index, methods=["GET"], response_class=RedirectResponse)
        self.router.add_api_route("/ui/streams", self.ui_streams_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/streams/{stream_id}/{file_path:path}", self.ui_streams_details, methods=["GET"])
        self.router.add_api_route("/ui/certificates", self.ui_certificates_list, methods=["GET"], response_class=HTMLResponse)
        self.router.add_api_route("/ui/certificates/{hierarchy}/{name}",
                                  self.ui_certificates_details, methods=["GET"], response_class=HTMLResponse)
        # HTTP APIs
        self.router.add_api_route("/streams", self.create_stream, methods=["POST"], status_code=201)
        self.router.add_api_route("/streams", self.list_streams, methods=["GET"])
        self.router.add_api_route("/streams/probe/{stream_id}/{file_path:path}", self.ffprobe_check, methods=["GET"])
        self.router.add_api_route("/streams/{stream_id}", self.manifest_upload, methods=["PUT"])
        self.router.add_api_route("/streams/{stream_id}/{file_path:path}", self.segment_upload, methods=["PUT"], status_code=202)
        self.router.add_api_route("/streams/{stream_id}/{file_path:path}", self.segment_download, methods=["GET"])
        self.router.add_api_route("/certs", self.list_certs, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{hierarchy}/{name}", self.certificate_details, methods=["GET"], status_code=200)
        self.router.add_api_route("/certs/{name}/keypair", self.create_client_keypair, methods=["POST"])
        self.router.add_api_route("/certs/{name}/sign", self.sign_client_certificate, methods=["POST"])

    # UI website
    def index(self):
        """Redirect the root URL to the streams overview page."""
        return RedirectResponse("/ui/streams")

    def ui_streams_list(self, request: Request):
        """Render the list of known streams."""
        s = self.list_streams()
        return self.templates.TemplateResponse(
            request=request, name="streams_list.jinja2", context={"streams": s["streams"]}
        )

    def ui_streams_details(self, request: Request, stream_id: int, file_path: str):
        """Render details for one uploaded file: either the stored client-cert
        JSON (for ``.crt`` companions) or the ffprobe analysis of the media."""
        context = {}
        context['streams'] = self.list_streams()['streams']
        context['stream_id'] = stream_id
        context['file_path'] = file_path
        if file_path.endswith('.crt'):
            context['type'] = 'cert'
            p = self.wd.path("streams", str(stream_id), file_path)
            with open(p, "r") as f:
                context['cert'] = json.load(f)
        else:
            context['type'] = 'media'
            context['probe'] = self.ffprobe_check(stream_id, file_path)
            context['pretty_probe'] = json.dumps(context['probe'], sort_keys=True, indent=4)
        return self.templates.TemplateResponse(request=request, name="streams_details.jinja2", context=context)

    def ui_certificates_list(self, request: Request):
        """Render the list of certificates from both hierarchies."""
        return self.templates.TemplateResponse(
            request=request, name="certificates_list.jinja2", context={"certs": self.list_certs()}
        )

    def ui_certificates_details(self, request: Request, hierarchy: str, name: str):
        """Render the decoded content of a single key or certificate file."""
        context = self.certificate_details(hierarchy, name)
        context["certs"] = self.list_certs()
        return self.templates.TemplateResponse(request=request, name="certificates_details.jinja2", context=context)

    # APIs
    def create_stream(self):
        """Allocate the next stream id (max existing id + 1) and create its folder."""
        dirs = [d for d in pathlib.Path(self.wd.path("streams")).iterdir() if d.is_dir()]
        last_stream = int(dirs[-1].name) if len(dirs) > 0 else 0
        stream_id = last_stream + 1
        self.wd.mkdir("streams", str(stream_id))
        return {"stream_id": stream_id}

    def list_streams(self):
        """List every stream directory together with its uploaded files."""
        dirs = [d for d in pathlib.Path(self.wd.path("streams")).iterdir() if d.is_dir()]

        def stream_files(dir: Path):
            return [f.relative_to(dir) for f in dir.glob("**/*") if f.is_file()]

        streams = [{"id": d.name, "files": stream_files(d)} for d in dirs]
        return {"streams": streams}

    async def manifest_upload(self, stream_id: int, req: Request):
        """The DASH manifest is uploaded onto the base path without any file path"""
        # Here we assume that no camera will upload an index.mpd file on their own.
        # That is something that may not be true, in which case we would have to add
        # another layer of abstraction on the file system where we can store the mpd
        # file and the camera direct uploads.
        return await self.segment_upload("index.mpd", stream_id, req)

    async def segment_upload(self, file_path: str, stream_id: int, req: Request):
        """Store an uploaded segment, plus the parsed client certificate next to it.

        See https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.getpeercert
        for the exact content of the stored certificate JSON.
        """
        # `transport` is available thanks to the H11Protocol monkey patch above.
        cert_details = req.scope["transport"].get_extra_info("ssl_object").getpeercert()
        logging.debug(f"segment_upload. stream_id={stream_id} file_path:{file_path}")
        if not self.wd.path("streams", str(stream_id)).exists():
            raise HTTPException(404, detail="Stream doesn't exists")
        # SECURITY NOTE(review): `file_path` comes straight from the URL and is
        # joined into the working directory — a crafted `..` path could escape
        # it. Acceptable for a development tool, but do not expose publicly.
        dst = self.wd.mkdir("streams", str(stream_id), file_path, is_file=True)
        with open(dst.with_suffix(dst.suffix + ".crt"), "w") as f:
            f.write(json.dumps(cert_details))
        with open(dst, "wb") as f:
            async for chunk in req.stream():
                f.write(chunk)
        return Response(status_code=202)

    def ffprobe_check(self, stream_id: int, file_path: str):
        """Run ffprobe on an uploaded file and return its parsed JSON report."""
        p = self.wd.path("streams", str(stream_id), file_path)
        if not p.exists():
            # Must be raised, not returned: returning the exception object made
            # FastAPI serialize it inside a 200 response.
            raise HTTPException(404, detail="Stream doesn't exists")
        proc = subprocess.run(
            ["ffprobe", "-show_streams", "-show_format", "-output_format", "json", str(p.absolute())],
            capture_output=True
        )
        if proc.returncode != 0:
            # TODO Add more details (maybe stderr) to the response
            raise HTTPException(500)
        return json.loads(proc.stdout)

    async def segment_download(self, file_path: str, stream_id: int):
        """Serve a previously uploaded file back to the caller."""
        return FileResponse(self.wd.path("streams", str(stream_id), file_path))

    def list_certs(self):
        """List the file names present in the server and device hierarchies."""
        server = [f.name for f in pathlib.Path(self.wd.path("certs", "server")).iterdir()]
        device = [f.name for f in pathlib.Path(self.wd.path("certs", "device")).iterdir()]
        return {"server": server, "device": device}

    def certificate_details(self, hierarchy: str, name: str):
        """Decode one on-disk key (``.key``) or certificate (anything else)."""
        data = pathlib.Path(self.wd.path("certs", hierarchy, name)).read_bytes()
        type = "key" if name.endswith(".key") else "cert"
        key = None
        cert = None
        if type == "key":
            key = serialization.load_pem_private_key(data, None)
            key = {
                "key_size": key.key_size,
                "private_key": key.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                ),
                "public_key": key.public_key().public_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PublicFormat.PKCS1,
                ),
            }
        else:
            cert = x509.load_pem_x509_certificate(data)
            cert = {
                "public_cert": cert.public_bytes(serialization.Encoding.PEM),
                "serial_number": hex(cert.serial_number),
                "not_valid_before": cert.not_valid_before_utc,
                "not_valid_after": cert.not_valid_after_utc,
                # public_key? fingerprint?
                "issuer": cert.issuer.rfc4514_string(),
                "subject": cert.subject.rfc4514_string(),
                "extensions": [str(ext) for ext in cert.extensions]
            }
        return {"type": type, "key": key, "cert": cert}

    def create_client_keypair(self, name: str, override: bool = True):
        """Generate (or reuse) a device keypair and report where it is stored."""
        (key, cert, reused) = self.device_hierarchy.gen_keypair(name, override)
        # Previously a set literal `{key, cert, reused}` was returned, which is
        # unordered and could even collapse equal members; return a dict instead.
        return {"key": key, "cert": cert, "reused": reused}

    def sign_client_certificate(
        self, name: str, req: SignClientCertificate, override: bool = True
    ):
        """Sign the CSR in the request body with the device hierarchy root."""
        (key, cert, reused) = self.device_hierarchy.gen_cert(name, req.csr, override)
        # Same set-literal fix as create_client_keypair.
        return {"key": key, "cert": cert, "reused": reused}
class PushAvContext:
    """Hold the context for a full Push AV Server including temporary disk, CA hierarchies and web server"""

    def __init__(self, host: Optional[str], port: Optional[int], working_directory: Optional[str], dns: Optional[str]):
        self.directory = WorkingDirectory(working_directory)
        self.host = host
        self.port = port
        # zeroconf requires the instance name to end with the registered
        # service type ("_http._tcp.local."); the previous "_tcp_" spelling
        # did not match the type registered below.
        self.dns = "localhost" if dns is None else f"{dns}._http._tcp.local."
        self.proc: multiprocessing.Process | None = None
        # Create CA hierarchies (for webserver and devices)
        self.device_hierarchy = CAHierarchy(self.directory.mkdir("certs", "device"), "device", "client")
        self.server_hierarchy = CAHierarchy(self.directory.mkdir("certs", "server"), "server", "server")
        (self.server_key_file, self.server_cert_file, _) = self.server_hierarchy.gen_keypair(self.dns, override=True)
        # mDNS configuration. Registration only happens if the dns isn't localhost.
        self.zeroconf = Zeroconf()
        self.svc_info = None
        if self.dns != "localhost":
            self.svc_info = ServiceInfo(
                "_http._tcp.local.",
                name=self.dns,
                addresses=[socket.inet_aton("127.0.0.1")],
                # Advertise the actual serving port instead of a hard-coded 1234.
                port=int(self.port),
            )
        # Streams holder
        self.directory.mkdir("streams")
        self.app = FastAPI()
        self.app.mount("/static", StaticFiles(directory="static"), name="static")
        pas = PushAvServer(self.directory, self.device_hierarchy)
        self.app.include_router(pas.router)

    def start_in_background(self):
        """Launch uvicorn (with mutual-TLS enabled) in a forked child process."""
        if self.proc:
            logging.warning("Attempting to start a server when one is already running, no new server is being started.")
            return
        # Advertise over mDNS
        if self.svc_info:
            logging.info("Advertising the service as %s", self.svc_info)
            self.zeroconf.register_service(self.svc_info)

        def background_job():
            # Start the web server
            try:
                uvicorn.run(
                    self.app,
                    host=self.host,
                    port=self.port,
                    ssl_keyfile=self.server_key_file,
                    ssl_certfile=self.server_cert_file,
                    # Client certificates are requested but optional, so plain
                    # browsers can still reach the UI.
                    ssl_cert_reqs=ssl.CERT_OPTIONAL,
                    ssl_ca_certs=self.device_hierarchy.root_cert_path,
                )
            finally:
                if self.svc_info:
                    self.zeroconf.unregister_service(self.svc_info)

        # Spawning the function results in python not being able to pickle the full context
        # (most notably cryptography's rust bindings). So instead we force use forks as the
        # way to create processes.
        multiprocessing.set_start_method('fork')
        self.proc = multiprocessing.Process(target=background_job, daemon=True)
        self.proc.start()

    def terminate(self):
        """Stop the background server (when one is running) and clean up disk state."""
        # Guard: terminate() may be called before start_in_background().
        if self.proc is not None:
            self.proc.terminate()
        self.directory.cleanup()
if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s|%(name)-8s|%(levelname)-5s|%(message)s",
        level=logging.DEBUG,
        datefmt="%H:%M:%S",
    )
    parser = argparse.ArgumentParser(
        prog="push_av_tool.py",
        description="Tooling to help test Matter's Push AV capabilities",
    )
    parser.add_argument("--host", default="localhost")
    # type=int so a CLI-supplied port reaches uvicorn as an int, like the default.
    parser.add_argument("--port", default=1234, type=int)
    parser.add_argument(
        "--working-directory",
        help="Where to store content like certificates or uploaded streams. "
        "Default to a temporary directory.",
    )
    parser.add_argument(
        "--dns", help="A mDNS record to advertise, or none if left empty."
    )
    args = parser.parse_args()
    ctx = PushAvContext(args.host, args.port, args.working_directory, args.dns)
    ctx.start_in_background()
    print(ctx.proc)
    try:
        # Block until the server process exits.
        ctx.proc.join()
    finally:
        # Ensure cleanup (mDNS, temp dir) also runs on Ctrl-C.
        ctx.terminate()