blob: a1cb0e98e9deb3bdeefd0139c4bb90c88809aa6c [file] [log] [blame]
// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
use std::{
collections::HashMap,
path::{Path, PathBuf},
};
use tokio::{
fs::{self, File},
io::{AsyncSeekExt, AsyncWriteExt},
sync::mpsc,
};
use crate::{
digest::ExpectedDigest,
executor::{ExecutionContext, ExecutionStatus, ExecutionStatusMsg},
target::{Download, DownloadVariant},
Error, Result,
};
/// Maximum number of times a single download is attempted before the target
/// is reported as failed.
const MAX_DOWNLOAD_ATTEMPTS: u32 = 5;
/// The file format of a download target.
///
/// The format decides both where the temporary download file is placed and
/// how that file is handled once the download has completed.
#[derive(Debug)]
pub enum Format {
/// An uncompressed binary file, storing the relative path to which it
/// should be downloaded.
///
/// The path is joined onto the work directory while downloading and onto
/// the output directory for the final file.
Binary(PathBuf),
}
/// Builds the concrete URL for a download target.
///
/// Substitution parameters come from the target's own `url_parameters`. When
/// a `variant` is supplied, its parameters are applied afterwards, so a
/// variant value replaces a target-wide value with the same key.
///
/// # Errors
/// Propagates any error returned while substituting the parameters into the
/// target's URL.
fn download_url(metadata: &Download, variant: Option<&DownloadVariant>) -> Result<String> {
    let target_wide = metadata.url_parameters.iter();
    let variant_specific = variant
        .into_iter()
        .flat_map(|selected| selected.url_parameters.iter());

    // Later entries win on key collisions, so variant parameters must come last.
    let substitutions: HashMap<&str, &str> = target_wide
        .chain(variant_specific)
        .map(|(name, value)| (name.as_str(), value.as_str()))
        .collect();

    metadata.url.substitute(&substitutions)
}
/// Downloads the file described by a download target and reports progress and
/// completion over a status channel.
pub(crate) struct Downloader<'a> {
/// Execution context supplying the target's name and platform as well as the
/// work and output directories.
context: &'a ExecutionContext,
/// Description of the download: URL, URL parameters, digest, format, and
/// variants.
metadata: &'a Download,
/// Channel on which progress updates and the final status are sent.
status_tx: &'a mpsc::UnboundedSender<ExecutionStatusMsg>,
}
/// Reason a single download attempt did not succeed.
enum Status {
/// The attempt failed, but another attempt may be made.
Retry,
/// The attempt failed fatally; the wrapped error is returned to the caller
/// without any further attempts.
Failure(Error),
}
impl<'a> Downloader<'a> {
/// Instantiates a downloader for a target within the context of an
/// [`Executor`](crate::executor::Executor) run.
///
/// The constructor only stores the given references; it performs no
/// validation and cannot fail.
///
/// * `context` - execution context of the target being run.
/// * `metadata` - description of the download to perform.
/// * `status_tx` - channel on which status messages are sent.
pub fn new(
context: &'a ExecutionContext,
metadata: &'a Download,
status_tx: &'a mpsc::UnboundedSender<ExecutionStatusMsg>,
) -> Self {
Self {
context,
metadata,
status_tx,
}
}
/// Run the target to completion, downloading the file to its output directory.
pub async fn run(&self) -> Result<()> {
let status = match self.run_result().await {
Ok(_) => ExecutionStatus::Complete,
Err(e) => ExecutionStatus::Failed(e.to_string()),
};
// TODO(frolv): Remove this and use the return value of this function.
self.status_tx
.send(ExecutionStatusMsg {
name: self.context.target.full_name(),
status,
})
.map_err(|e| Error::StringErrorPlaceholder(format!("error sending status: {e}")))
}
async fn run_result(&self) -> Result<()> {
let variant = self
.metadata
.variants
.iter()
.find(|v| v.matches.matches(&self.context.target_platform));
let url = download_url(self.metadata, variant)?;
let digest = match variant {
Some(v) => v.digest.as_ref().or(self.metadata.digest.as_ref()),
None => self.metadata.digest.as_ref(),
};
let Some(digest) = digest else {
return Err(Error::StringErrorPlaceholder("no digest provided".into()));
};
let tmpfile_path = self.temporary_file();
let mut tmpfile = File::create(&tmpfile_path).await?;
self.retry_download(&url, digest, &mut tmpfile).await?;
match &self.metadata.format {
Format::Binary(bin_name) => {
self.handle_binary_file(&tmpfile, &tmpfile_path, bin_name)
.await
}
}
}
/// Attempts to download from the URL to `tmpfile` several times unless a
/// fatal error occurs.
///
/// Up to [`MAX_DOWNLOAD_ATTEMPTS`] attempts are made. Before each attempt the
/// file is rewound and truncated so that every attempt starts from an empty
/// file.
///
/// # Errors
/// Returns the underlying error as soon as an attempt fails fatally or the
/// file cannot be reset, and a generic error once all attempts have been used
/// up.
async fn retry_download(
    &self,
    url: &str,
    digest: &ExpectedDigest,
    tmpfile: &mut File,
) -> Result<()> {
    for _attempt in 0..MAX_DOWNLOAD_ATTEMPTS {
        // Rewinding alone would leave bytes from a previous, longer attempt
        // behind the newly written data. The digest is computed over the
        // streamed bytes rather than the file, so such leftovers would go
        // unnoticed; truncate to guarantee the file holds only this attempt.
        tmpfile.rewind().await?;
        tmpfile.set_len(0).await?;

        match self.run_download(url, digest, tmpfile).await {
            Ok(()) => return Ok(()),
            Err(Status::Retry) => continue,
            Err(Status::Failure(e)) => return Err(e),
        }
    }

    Err(Error::StringErrorPlaceholder(format!(
        "failed to download after {} attempts",
        MAX_DOWNLOAD_ATTEMPTS,
    )))
}
/// Performs a single download attempt, streaming the response body into
/// `tmpfile` while feeding the same bytes to the digest verifier.
///
/// A progress message is sent on the status channel after each chunk, but
/// only when the response reports a content length.
///
/// # Errors
/// * [`Status::Retry`] for request or transfer errors that are not HTTP
///   status errors.
/// * [`Status::Failure`] when the server responds with an HTTP error status,
///   when writing or flushing `tmpfile` fails, when a progress message cannot
///   be sent, or when the downloaded data does not match `digest`.
async fn run_download(
    &self,
    url: &str,
    digest: &ExpectedDigest,
    tmpfile: &mut File,
) -> std::result::Result<(), Status> {
    let download_error_to_status = |e: reqwest::Error| {
        // TODO(frolv): Expand this with more error cases and logging.
        if e.is_status() {
            Status::Failure(Error::StringErrorPlaceholder(format!(
                "failed to download: {e}",
            )))
        } else {
            Status::Retry
        }
    };

    let mut verifier = digest.verifier();

    // `reqwest::get` succeeds for any HTTP response, including 4xx/5xx ones.
    // Turn error statuses into errors explicitly so they are reported as such
    // instead of the error body being written out and rejected later as a
    // digest mismatch.
    let mut response = reqwest::get(url)
        .await
        .and_then(reqwest::Response::error_for_status)
        .map_err(download_error_to_status)?;
    let maybe_length = response.content_length();
    let mut downloaded_size = 0u64;

    // Stream the chunks of the file, updating the checksum and sending
    // progress reports.
    loop {
        let chunk = match response.chunk().await {
            Ok(Some(c)) => c,
            Ok(None) => break,
            Err(e) => return Err(download_error_to_status(e)),
        };

        tmpfile
            .write_all(&chunk)
            .await
            .map_err(|e| Status::Failure(e.into()))?;
        verifier.update(&chunk);
        downloaded_size += chunk.len() as u64;

        if let Some(len) = maybe_length {
            // TODO(frolv): Create a context API for sending progress updates.
            self.status_tx
                .send(ExecutionStatusMsg {
                    name: self.context.target.full_name(),
                    status: ExecutionStatus::InProgress {
                        current: downloaded_size,
                        total: len,
                        unit: "B",
                    },
                })
                .map_err(|e| {
                    Status::Failure(Error::StringErrorPlaceholder(format!(
                        "error sending status: {e}"
                    )))
                })?;
        }
    }

    // Writes to a tokio `File` may still be completing in the background.
    // Flushing waits for them, so a failed final write is reported here
    // rather than being lost.
    tmpfile
        .flush()
        .await
        .map_err(|e| Status::Failure(e.into()))?;

    if !verifier.verify() {
        return Err(Status::Failure(Error::StringErrorPlaceholder(
            "digest of downloaded file does not match expected".into(),
        )));
    }

    Ok(())
}
/// Makes `tmpfile` executable and renames it to its final binary path.
///
/// On UNIX targets the file mode is set to `0o755`. On other targets the
/// permissions are left untouched. The file at `tmpfile_path` is then renamed
/// to `bin_name` inside the context's output directory.
///
/// # Errors
/// Returns an error if the file's metadata cannot be read, if its permissions
/// cannot be changed, or if the rename fails.
async fn handle_binary_file(
    &self,
    tmpfile: &File,
    tmpfile_path: &Path,
    bin_name: &Path,
) -> Result<()> {
    // TODO(frolv): Handle non-UNIX platforms.
    //
    // This must be `#[cfg(unix)]` rather than `if cfg!(unix)`: the macro only
    // yields a boolean and still compiles the guarded code everywhere, which
    // fails on targets where `std::os::unix` does not exist.
    #[cfg(unix)]
    {
        use std::os::unix::prelude::PermissionsExt;

        let mut permissions = tmpfile.metadata().await?.permissions();
        permissions.set_mode(0o755);
        tmpfile.set_permissions(permissions).await?;
    }
    // Keep the parameter "used" on targets where the block above is removed.
    #[cfg(not(unix))]
    let _ = tmpfile;

    let download_path = self.context.output_dir.join(bin_name);
    fs::rename(tmpfile_path, &download_path).await?;
    Ok(())
}
/// Returns the path of the temporary file that receives the downloaded
/// content.
///
/// The path is the format's relative file path joined onto the context's
/// work directory.
fn temporary_file(&self) -> PathBuf {
    let relative_path = match &self.metadata.format {
        Format::Binary(bin_name) => bin_name.as_path(),
    };
    self.context.work_dir.join(relative_path)
}
}