| // Copyright 2023 The Pigweed Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| // use this file except in compliance with the License. You may obtain a copy of |
| // the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| // License for the specific language governing permissions and limitations under |
| // the License. |
| |
| use std::{ |
| collections::HashMap, |
| path::{Path, PathBuf}, |
| }; |
| |
| use tokio::{ |
| fs::{self, File}, |
| io::{AsyncSeekExt, AsyncWriteExt}, |
| sync::mpsc, |
| }; |
| |
| use crate::{ |
| digest::ExpectedDigest, |
| executor::{ExecutionContext, ExecutionStatus, ExecutionStatusMsg}, |
| target::{Download, DownloadVariant}, |
| Error, Result, |
| }; |
| |
/// Upper bound on how many times a single download is attempted before the
/// target is reported as failed.
const MAX_DOWNLOAD_ATTEMPTS: u32 = 5;
| |
/// Describes how the payload of a download target is packaged.
#[derive(Debug)]
pub enum Format {
    /// A plain, uncompressed binary. The wrapped path is relative and names
    /// the location the file is downloaded to.
    Binary(PathBuf),
}
| |
| /// Returns the URL of a download target, filling in any parameter substitutions. |
| fn download_url(metadata: &Download, variant: Option<&DownloadVariant>) -> Result<String> { |
| let mut params: HashMap<&str, &str> = metadata |
| .url_parameters |
| .iter() |
| .map(|(k, v)| (k.as_str(), v.as_str())) |
| .collect(); |
| if let Some(variant) = variant { |
| params.extend( |
| variant |
| .url_parameters |
| .iter() |
| .map(|(k, v)| (k.as_str(), v.as_str())), |
| ); |
| } |
| |
| metadata.url.substitute(¶ms) |
| } |
| |
| pub(crate) struct Downloader<'a> { |
| context: &'a ExecutionContext, |
| metadata: &'a Download, |
| status_tx: &'a mpsc::UnboundedSender<ExecutionStatusMsg>, |
| } |
| |
| enum Status { |
| Retry, |
| Failure(Error), |
| } |
| |
| impl<'a> Downloader<'a> { |
| /// Instantiates a downloader for a target within the context of an |
| /// [`Executor`](crate::executor::Executor) run. |
| /// |
| /// # Errors |
| /// If `context.target` is not a downloadable target, returns an error. |
| pub fn new( |
| context: &'a ExecutionContext, |
| metadata: &'a Download, |
| status_tx: &'a mpsc::UnboundedSender<ExecutionStatusMsg>, |
| ) -> Self { |
| Self { |
| context, |
| metadata, |
| status_tx, |
| } |
| } |
| |
| /// Run the target to completion, downloading the file to its output directory. |
| pub async fn run(&self) -> Result<()> { |
| let status = match self.run_result().await { |
| Ok(_) => ExecutionStatus::Complete, |
| Err(e) => ExecutionStatus::Failed(e.to_string()), |
| }; |
| |
| // TODO(frolv): Remove this and use the return value of this function. |
| self.status_tx |
| .send(ExecutionStatusMsg { |
| name: self.context.target.full_name(), |
| status, |
| }) |
| .map_err(|e| Error::StringErrorPlaceholder(format!("error sending status: {e}"))) |
| } |
| |
| async fn run_result(&self) -> Result<()> { |
| let variant = self |
| .metadata |
| .variants |
| .iter() |
| .find(|v| v.matches.matches(&self.context.target_platform)); |
| |
| let url = download_url(self.metadata, variant)?; |
| |
| let digest = match variant { |
| Some(v) => v.digest.as_ref().or(self.metadata.digest.as_ref()), |
| None => self.metadata.digest.as_ref(), |
| }; |
| let Some(digest) = digest else { |
| return Err(Error::StringErrorPlaceholder("no digest provided".into())); |
| }; |
| |
| let tmpfile_path = self.temporary_file(); |
| let mut tmpfile = File::create(&tmpfile_path).await?; |
| |
| self.retry_download(&url, digest, &mut tmpfile).await?; |
| |
| match &self.metadata.format { |
| Format::Binary(bin_name) => { |
| self.handle_binary_file(&tmpfile, &tmpfile_path, bin_name) |
| .await |
| } |
| } |
| } |
| |
| /// Attempt to download from the URL to `tmpfile` several times unless a fatal error occurs. |
| async fn retry_download( |
| &self, |
| url: &str, |
| digest: &ExpectedDigest, |
| tmpfile: &mut File, |
| ) -> Result<()> { |
| for _attempt in 0..MAX_DOWNLOAD_ATTEMPTS { |
| tmpfile.rewind().await?; |
| |
| match self.run_download(url, digest, tmpfile).await { |
| Ok(()) => return Ok(()), |
| Err(Status::Retry) => continue, |
| Err(Status::Failure(e)) => return Err(e), |
| }; |
| } |
| |
| Err(Error::StringErrorPlaceholder(format!( |
| "failed to download after {} attempts", |
| MAX_DOWNLOAD_ATTEMPTS, |
| ))) |
| } |
| |
| async fn run_download( |
| &self, |
| url: &str, |
| digest: &ExpectedDigest, |
| tmpfile: &mut File, |
| ) -> std::result::Result<(), Status> { |
| let download_error_to_status = |e: reqwest::Error| { |
| // TODO(frolv): Expand this with more error cases and logging. |
| if e.is_status() { |
| Status::Failure(Error::StringErrorPlaceholder(format!( |
| "failed to download: {e}", |
| ))) |
| } else { |
| Status::Retry |
| } |
| }; |
| |
| let mut verifier = digest.verifier(); |
| |
| let mut response = reqwest::get(url).await.map_err(download_error_to_status)?; |
| let maybe_length = response.content_length(); |
| let mut downloaded_size = 0u64; |
| |
| // Stream the chunks of the file, updating the checksum and sending |
| // progress reports. |
| loop { |
| let chunk = match response.chunk().await { |
| Ok(Some(c)) => c, |
| Ok(None) => break, |
| Err(e) => return Err(download_error_to_status(e)), |
| }; |
| |
| tmpfile |
| .write_all(&chunk) |
| .await |
| .map_err(|e| Status::Failure(e.into()))?; |
| verifier.update(&chunk); |
| downloaded_size += chunk.len() as u64; |
| |
| if let Some(len) = maybe_length { |
| // TODO(frolv): Create a context API for sending progress updates. |
| self.status_tx |
| .send(ExecutionStatusMsg { |
| name: self.context.target.full_name(), |
| status: ExecutionStatus::InProgress { |
| current: downloaded_size, |
| total: len, |
| unit: "B", |
| }, |
| }) |
| .map_err(|e| { |
| Status::Failure(Error::StringErrorPlaceholder(format!( |
| "error sending status: {e}" |
| ))) |
| })?; |
| } |
| } |
| |
| if !verifier.verify() { |
| return Err(Status::Failure(Error::StringErrorPlaceholder( |
| "digest of downloaded file does not match expected".into(), |
| ))); |
| } |
| |
| Ok(()) |
| } |
| |
| /// Makes `tmpfile` executable and renames it to its final binary path. |
| async fn handle_binary_file( |
| &self, |
| tmpfile: &File, |
| tmpfile_path: &Path, |
| bin_name: &Path, |
| ) -> Result<()> { |
| // TODO(frolv): Handle non-UNIX platforms. |
| if cfg!(unix) { |
| use std::os::unix::prelude::PermissionsExt; |
| let mut permissions = tmpfile.metadata().await?.permissions(); |
| permissions.set_mode(0o755); |
| tmpfile.set_permissions(permissions).await?; |
| } |
| |
| let download_path = self.context.output_dir.join(bin_name); |
| fs::rename(&tmpfile_path, &download_path).await?; |
| Ok(()) |
| } |
| |
| /// Returns the path to a temporary file to which downloaded content can be |
| /// written. |
| fn temporary_file(&self) -> PathBuf { |
| match &self.metadata.format { |
| Format::Binary(bin_name) => self.context.work_dir.join(bin_name), |
| } |
| } |
| } |