blob: 06034832e38ce4ae3c03673b34253e538d2a7d5a [file] [log] [blame]
// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
use std::{
collections::{BTreeMap, HashSet},
path::PathBuf,
sync::Arc,
};
use futures::future::join_all;
use tokio::{sync::mpsc, task::JoinHandle};
use crate::{registry::Dependency, Error, Project, Result, Target};
#[cfg(test)]
use crate::target::{Fake, Metadata};
/// Status of a single target's execution, carried in an
/// `ExecutionStatusMsg` from a worker to the executor.
#[derive(Debug, Eq, PartialEq)]
pub enum ExecutionStatus {
/// Intermediate progress report. The executor in this file receives these
/// messages but currently ignores them.
// NOTE(review): presumably `current` out of `total` steps measured in
// `unit`s; confirm against the code that emits this variant.
#[allow(unused)]
InProgress {
current: u64,
total: u64,
unit: &'static str,
},
/// The target finished successfully. The executor marks the target as
/// completed and dispatches any waiting targets this unblocks.
#[allow(unused)]
Complete,
/// The target failed. The payload is the error text, which the executor
/// returns to its caller as an error.
Failed(String),
}
/// Message sent over the status channel to tell the executor about the state
/// of one target.
#[derive(Debug, Eq, PartialEq)]
pub struct ExecutionStatusMsg {
/// Full name of the target this status applies to.
pub name: String,
/// The status being reported for that target.
pub status: ExecutionStatus,
}
/// Unit of work handed from the executor to a worker over the dispatch
/// channel.
#[derive(Debug)]
pub struct ExecutionContext {
/// The target the worker should execute.
target: Arc<Target>,
// Not read anywhere in this file yet; the executor currently fills it with
// an empty path (see the TODO in `dispatch_target`).
#[allow(unused)]
output_dir: PathBuf,
// Not read anywhere in this file yet; the executor currently fills it with
// an empty path (see the TODO in `dispatch_target`).
#[allow(unused)]
work_dir: PathBuf,
}
/// Schedules the targets of a `Project`, hands them to a pool of worker tasks
/// and tracks which targets are waiting, in flight, or completed.
pub(crate) struct Executor<'a> {
// Below, `BTreeMap`s are used instead of `HashMap`s to:
// 1. allow the executor to remove a single element easily.
// 2. give a deterministic execution order.
// The set of targets that have no unfinished dependencies, have been
// dispatched to the workers and are waiting for a completion status.
pending_targets: BTreeMap<String, Arc<Target>>,
// The set of targets waiting on dependencies to finish.
waiting_targets: BTreeMap<String, Arc<Target>>,
// The set of targets that have completed. HashSet is used here because
// ordering does not matter and constant time lookups are favored.
completed_targets: HashSet<String>,
// Sending half of the channel used to hand `ExecutionContext`s to the
// workers. Closing it tells the workers to shut down.
dispatch_tx: async_channel::Sender<ExecutionContext>,
// Receiving half of the channel on which workers report target status.
status_rx: mpsc::UnboundedReceiver<ExecutionStatusMsg>,
// The project whose registry is used to resolve targets and their
// transitive dependencies.
project: &'a Project,
// Join handles of the spawned worker tasks; awaited in `close()`.
workers: Vec<JoinHandle<()>>,
}
impl<'a> Executor<'a> {
/// Creates an executor for `project` and spawns `num_workers` worker tasks.
///
/// Every worker receives its own clone of the status sender and of the
/// dispatch receiver. The executor keeps the dispatch sender and the status
/// receiver, and starts with no pending, waiting, or completed targets.
#[allow(unused)]
pub fn new(project: &'a Project, num_workers: usize) -> Executor<'a> {
    let (dispatch_tx, dispatch_rx) = async_channel::unbounded();
    let (status_tx, status_rx) = mpsc::unbounded_channel();

    // Start the worker pool, keeping the join handles so `close()` can wait
    // for the workers to terminate.
    let mut workers = Vec::with_capacity(num_workers);
    for _ in 0..num_workers {
        workers.push(Worker::spawn(status_tx.clone(), dispatch_rx.clone()));
    }

    Executor {
        project,
        workers,
        dispatch_tx,
        status_rx,
        pending_targets: BTreeMap::new(),
        waiting_targets: BTreeMap::new(),
        completed_targets: HashSet::new(),
    }
}
#[allow(unused)]
pub async fn close(mut self) {
// Close the dispatch channel to signal to the workers that the
// executor is shutting down.
self.dispatch_tx.close();
// Wait for the workers to terminate.
join_all(self.workers).await;
}
/// Reports whether `target` still needs to run, i.e. whether its full name
/// has not yet been recorded in `completed_targets`.
fn is_target_dirty(&self, target: &Target) -> bool {
    let name = target.full_name();
    let already_completed = self.completed_targets.contains(&name);
    !already_completed
}
/// Returns true if at least one of the target's dependencies is missing from
/// `completed_targets`.
///
/// Takes `completed_targets` rather than `&self` so that it can be called
/// from places that already hold a mutable reference into `self`.
fn is_target_blocked(target: &Target, completed_targets: &HashSet<String>) -> bool {
    target
        .dependencies()
        .any(|dep| !completed_targets.contains(dep))
}
/// Schedules a target and all of its transitive dependencies for execution.
///
/// Targets that are already completed are skipped. Targets that still have
/// incomplete dependencies are parked in `waiting_targets`; the rest are
/// dispatched to the workers right away.
///
/// Returns an error if the registry lookup fails, if any transitive
/// dependency is unresolved, or if dispatching a target fails.
async fn schedule_target(&mut self, target_name: &str) -> Result<()> {
    let registry = self.project.registry();
    for dependency in registry.transitive_dependencies(target_name)? {
        let target = match dependency {
            Dependency::Unresolved(name) => {
                return Err(Error::StringErrorPlaceholder(format!(
                    "Unresolved transitive dependency of {}: {}",
                    target_name, name
                )));
            }
            Dependency::Resolved(target) => target,
        };

        // Completed targets need no further work.
        if !self.is_target_dirty(target) {
            continue;
        }

        // Targets with unfinished dependencies wait until they are unblocked.
        if Self::is_target_blocked(target, &self.completed_targets) {
            self.waiting_targets
                .insert(target.full_name(), target.clone());
            continue;
        }

        // Everything else can be handed to the workers immediately.
        self.dispatch_target(target.clone()).await?;
    }
    Ok(())
}
/// Finds all waiting targets who's dependencies are complete
/// and dispatches them
async fn dispatch_unblocked_targets(&mut self) -> Result<()> {
// Cache targets to dispatch since we can't call async functions in
// .retain(). This is a perfect use case for `BTreeMap::drain_filter`
// if/when it is stabilized.
let mut targets_to_dispatch = Vec::new();
self.waiting_targets.retain(|_name, target| {
if Self::is_target_blocked(target, &self.completed_targets) {
// If blocked, keep it in the `waiting_targets` set.
true
} else {
// If not blocked, dispatch the target and remove it from
// `waiting_targets`.
targets_to_dispatch.push(target.clone());
false
}
});
// Dispatch newly unblocked targets.
for target in targets_to_dispatch {
self.dispatch_target(target).await?;
}
Ok(())
}
/// Records that `target_name` has finished and dispatches every waiting
/// target that this completion unblocks.
async fn handle_completed_target(&mut self, target_name: &str) -> Result<()> {
    // The target is no longer in flight...
    self.pending_targets.remove(target_name);
    // ...and now counts as a satisfied dependency.
    self.completed_targets.insert(target_name.to_owned());
    // Anything that was only waiting on this target can run now.
    self.dispatch_unblocked_targets().await
}
/// Processes worker status messages until no dispatched target remains in
/// flight.
///
/// Returns immediately with `Ok(())` when `pending_targets` is already empty.
/// With nothing dispatched no worker will ever send a status, and because the
/// workers keep their status senders alive the receive would otherwise block
/// forever (for example when the requested target and all of its
/// dependencies were completed by an earlier call).
///
/// Returns an error if the status channel is closed or as soon as any target
/// reports `ExecutionStatus::Failed`.
async fn run_to_completion(&mut self) -> Result<()> {
    // Once nothing is pending, no further progress can be made: either all
    // of the work is done or the remaining waiting targets can never be
    // unblocked.
    while !self.pending_targets.is_empty() {
        let msg = self.status_rx.recv().await.ok_or_else(|| {
            Error::StringErrorPlaceholder("Error receiving status messages from workers".into())
        })?;
        match msg.status {
            // May dispatch newly unblocked targets, which keeps
            // `pending_targets` non-empty and the loop running.
            ExecutionStatus::Complete => self.handle_completed_target(&msg.name).await?,
            // TODO(konkers): Instead of returning errors immediately, errors should
            // be aggregated and returned once all running tasks have completed.
            ExecutionStatus::Failed(error_str) => {
                return Err(Error::StringErrorPlaceholder(error_str));
            }
            // Ignore InProgress events for now.
            ExecutionStatus::InProgress { .. } => (),
        }
    }
    Ok(())
}
/// Executes `target_name` together with its transitive dependencies.
///
/// First schedules the work, then processes worker status messages until
/// nothing is in flight. Any error from either step is returned.
#[allow(unused)]
pub async fn execute_target(&mut self, target_name: &str) -> Result<()> {
    self.schedule_target(target_name).await?;
    self.run_to_completion().await
}
async fn dispatch_target(&mut self, target: Arc<Target>) -> Result<()> {
self.pending_targets
.insert(target.full_name(), target.clone());
self.dispatch_tx
.send(ExecutionContext {
target,
// TODO(konkers): Add canonical location for `output_dir` and `work_dir`.
output_dir: "".into(),
work_dir: "".into(),
})
.await
.map_err(|e| Error::StringErrorPlaceholder(format!("Error dispatching target: {e}")))
}
}
/// A task that takes `ExecutionContext`s from the dispatch channel, executes
/// them and reports status back to the executor.
struct Worker {
// Sending half of the status channel; used to report target status to the
// executor.
status_tx: mpsc::UnboundedSender<ExecutionStatusMsg>,
// Receiving half of the dispatch channel; the worker stops once receiving
// from it fails.
dispatch_rx: async_channel::Receiver<ExecutionContext>,
}
impl Worker {
fn spawn(
status_tx: mpsc::UnboundedSender<ExecutionStatusMsg>,
dispatch_rx: async_channel::Receiver<ExecutionContext>,
) -> JoinHandle<()> {
let worker = Worker {
status_tx,
dispatch_rx,
};
tokio::spawn(worker.run())
}
fn send_status(&self, target: &Target, status: ExecutionStatus) -> Result<()> {
self.status_tx
.send(ExecutionStatusMsg {
name: target.full_name(),
status,
})
.map_err(|e| Error::StringErrorPlaceholder(format!("error sending status: {e}")))
}
/// Test-only handler: delegates execution of a fake target to `fake::run`,
/// passing it this worker's status sender.
#[cfg(test)]
async fn handle_fake_target(&self, target: &Target, metadata: &Fake) -> Result<()> {
    crate::fake::run(target, metadata, &self.status_tx).await
}
/// Worker main loop: receives execution contexts from the dispatch channel
/// and handles them one at a time.
///
/// The loop ends when receiving from the dispatch channel fails (the
/// executor closes it on shutdown) or when handling a target returns an
/// error.
// NOTE(review): when a handler returns an error the worker exits without
// sending a `Failed` status for that target, so the executor may keep waiting
// on it. Confirm whether the handler is expected to have reported the failure
// itself.
async fn run(self) {
    // A receive error means the dispatch channel is closed, i.e. the executor
    // is shutting down.
    while let Ok(context) = self.dispatch_rx.recv().await {
        let outcome = match context.target.metadata() {
            #[cfg(test)]
            Metadata::Fake(fake) => self.handle_fake_target(&context.target, fake).await,
            // Every other kind of target is reported back as failed.
            _ => self.send_status(
                &context.target,
                ExecutionStatus::Failed("Unsupported target type".into()),
            ),
        };

        if let Err(e) = outcome {
            // TODO(konkers): Log error.
            println!("worker terminating: {e}");
            break;
        }
    }
}
}
#[cfg(test)]
mod tests {
use crate::fake::{Event, FakeCoordinator};
use super::*;
/// Runs `dep-test:a` from the dependency test project on an executor with
/// two workers and checks the exact order of start/end events recorded by
/// the `FakeCoordinator`.
#[tokio::test]
async fn fake_execution() {
FakeCoordinator::get().await.reset_ticks();
// Run the executor in its own task so this test can drive the fake clock.
let join_handle = tokio::spawn(async move {
let project = Project::load("./src/test_projects/dependency_test").unwrap();
let mut executor = Executor::new(&project, 2);
executor.execute_target("dep-test:a").await.unwrap();
});
// Give the executor time to start.
tokio::task::yield_now().await;
// No ticks have been processed yet, so b and d have been started
// and no other progress has been made.
assert_eq!(
FakeCoordinator::get().await.clone_events(),
vec![Event::start(0, "dep-test:b"), Event::start(0, "dep-test:d"),]
.into_iter()
.collect()
);
// Tick the fake targets 100 times, which is sufficient for them to
// complete.
for _ in 0..100 {
FakeCoordinator::get().await.increment_ticks(1).await;
}
// Wait for the executor to finish.
join_handle.await.unwrap();
assert_eq!(
FakeCoordinator::get().await.clone_events(),
vec![
// b, d, and e are all dispatched, but we only have two workers
// so only b and d are started.
Event::start(0, "dep-test:b"),
Event::start(0, "dep-test:d"),
// Two ticks later b completes, and a tick after that e is
// started now that a worker is free.
Event::end(2, "dep-test:b"),
Event::start(3, "dep-test:e"),
// Later d and e complete, unblocking c, which starts a tick
// later.
Event::end(4, "dep-test:d"),
Event::end(8, "dep-test:e"),
Event::start(9, "dep-test:c"),
// Once c completes, a is unblocked; it starts a tick later and
// then runs to completion.
Event::end(12, "dep-test:c"),
Event::start(13, "dep-test:a"),
Event::end(14, "dep-test:a"),
]
.into_iter()
.collect()
);
}
}