// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

use std::{
    collections::{BTreeMap, HashSet},
    path::PathBuf,
    sync::Arc,
};

use futures::future::join_all;
use tokio::{sync::mpsc, task::JoinHandle};

use crate::{
    download::Downloader,
    platform::Platform,
    registry::Dependency,
    target::{Download, Metadata},
    Error, Project, Result, Target,
};

#[cfg(test)]
use crate::target::Fake;

/// Status reported by a worker for a target it is executing.
#[derive(Debug, Eq, PartialEq)]
pub enum ExecutionStatus {
    /// Incremental progress: `current` out of `total`, measured in `unit`.
    // NOTE(review): `allow(unused)` presumably because this variant is not
    // constructed in every build configuration — confirm before removing.
    #[allow(unused)]
    InProgress {
        current: u64,
        total: u64,
        unit: &'static str,
    },
    /// The target finished successfully.
    #[allow(unused)]
    Complete,
    /// The target failed; the payload is a human-readable description.
    Failed(String),
}

/// A status update tagged with the name of the target it refers to.
#[derive(Debug, Eq, PartialEq)]
pub struct ExecutionStatusMsg {
    /// Full name of the target, as produced by `Target::full_name()`.
    pub name: String,
    /// The status being reported for `name`.
    pub status: ExecutionStatus,
}

/// Everything a worker needs in order to execute a single target.
#[derive(Debug)]
pub struct ExecutionContext {
    /// The target to execute.
    pub target: Arc<Target>,
    /// Platform the target is being executed for.
    pub target_platform: Platform,
    /// Output directory for the target.
    // NOTE(review): the executor currently dispatches this as an empty path.
    pub output_dir: PathBuf,
    /// Working directory for the target (presumably scratch space — confirm).
    // NOTE(review): the executor currently dispatches this as an empty path.
    pub work_dir: PathBuf,
}

/// Schedules a target and its transitive dependencies onto a pool of worker
/// tasks. A target is dispatched only once all of its dependencies complete.
pub(crate) struct Executor<'a> {
    // `BTreeMap`s are used below instead of `HashMap`s to:
    // 1. allow the executor to remove a single element easily.
    // 2. give a deterministic (name-ordered) execution order.
    /// Targets with no unfinished dependencies that have been dispatched to
    /// the workers and are waiting for a status update.
    pending_targets: BTreeMap<String, Arc<Target>>,

    /// Targets still waiting on one or more dependencies to finish.
    waiting_targets: BTreeMap<String, Arc<Target>>,

    /// Full names of targets that have completed. Ordering is irrelevant
    /// here, so a `HashSet` is used for constant-time lookups.
    completed_targets: HashSet<String>,

    /// Sending half of the work queue shared by all workers.
    dispatch_tx: async_channel::Sender<ExecutionContext>,

    /// Receives status updates from every worker.
    status_rx: mpsc::UnboundedReceiver<ExecutionStatusMsg>,

    /// Project whose registry resolves targets and dependencies.
    project: &'a Project,

    /// Join handles of the spawned worker tasks.
    workers: Vec<JoinHandle<()>>,
}

impl<'a> Executor<'a> {
    #[allow(unused)]
    pub fn new(project: &'a Project, num_workers: usize) -> Executor<'a> {
        let (dispatch_tx, dispatch_rx) = async_channel::unbounded();
        let (status_tx, status_rx) = mpsc::unbounded_channel();

        let workers = (0..num_workers)
            .map(|_| Worker::spawn(status_tx.clone(), dispatch_rx.clone()))
            .collect();

        Executor {
            pending_targets: BTreeMap::new(),
            waiting_targets: BTreeMap::new(),
            completed_targets: HashSet::new(),
            dispatch_tx,
            status_rx,
            project,
            workers,
        }
    }

    #[allow(unused)]
    pub async fn close(mut self) {
        // Close the dispatch channel to signal to the workers that the
        // executor is shutting down.
        self.dispatch_tx.close();

        // Wait for the workers to terminate.
        join_all(self.workers).await;
    }

    /// Returns true if the target needs to run.
    fn is_target_dirty(&self, target: &Target) -> bool {
        !self.completed_targets.contains(&target.full_name())
    }

    /// Returns true if any of the targets dependencies are not completed.
    /// takes `completed_targets` instead of `&self` to allow it to be used
    /// in places that have mutable references on `self`.
    fn is_target_blocked(target: &Target, completed_targets: &HashSet<String>) -> bool {
        !target
            .dependencies()
            .all(|dep| completed_targets.contains(dep))
    }

    /// Schedules a target and all it's transitive dependencies to be executed.
    async fn schedule_target(&mut self, target_name: &str) -> Result<()> {
        let registry = self.project.registry();

        for target in registry.transitive_dependencies(target_name)? {
            let target = match target {
                Dependency::Resolved(target) => target,
                Dependency::Unresolved(name) => {
                    return Err(Error::StringErrorPlaceholder(format!(
                        "Unresolved transitive dependency of {}: {}",
                        target_name, name
                    )));
                }
            };

            // Only consider dirty targets.
            if self.is_target_dirty(target) {
                if Self::is_target_blocked(target, &self.completed_targets) {
                    // Blocked targets get put in the waiting set.
                    self.waiting_targets
                        .insert(target.full_name(), target.clone());
                } else {
                    // Unblocked targets dispatched immediately.
                    self.dispatch_target(target.clone()).await?;
                }
            }
        }

        Ok(())
    }

    /// Finds all waiting targets who's dependencies are complete
    /// and dispatches them
    async fn dispatch_unblocked_targets(&mut self) -> Result<()> {
        // Cache targets to dispatch since we can't call async functions in
        // .retain().  This is a perfect use case for `BTreeMap::drain_filter`
        //  if/when it is stabilized.
        let mut targets_to_dispatch = Vec::new();

        self.waiting_targets.retain(|_name, target| {
            if Self::is_target_blocked(target, &self.completed_targets) {
                // If blocked, keep it in the `waiting_targets` set.
                true
            } else {
                // If not blocked, dispatch the target and remove it from
                // `waiting_targets`.
                targets_to_dispatch.push(target.clone());
                false
            }
        });

        // Dispatch newly unblocked targets.
        for target in targets_to_dispatch {
            self.dispatch_target(target).await?;
        }

        Ok(())
    }

    async fn handle_completed_target(&mut self, target_name: &str) -> Result<()> {
        // Remove from pending_targets and mark complete.
        self.pending_targets.remove(target_name);
        self.completed_targets.insert(target_name.to_string());

        // Dispatch newly unblocked targets.
        self.dispatch_unblocked_targets().await?;

        Ok(())
    }

    async fn run_to_completion(&mut self) -> Result<()> {
        loop {
            let msg = self.status_rx.recv().await.ok_or_else(|| {
                Error::StringErrorPlaceholder("Error receiving status messages from workers".into())
            })?;

            match msg.status {
                ExecutionStatus::Complete => self.handle_completed_target(&msg.name).await?,

                // TODO(konkers): Instead of returning errors immediately, errors should
                // be aggregated and returned once all running tasks have completed.
                ExecutionStatus::Failed(error_str) => {
                    return Err(Error::StringErrorPlaceholder(format!(
                        "Error when executing target: {error_str}",
                    )));
                }

                // Ignore InProgress events for now.
                ExecutionStatus::InProgress {
                    current: _current,
                    total: _total,
                    unit: _unit,
                } => (),
            }

            // If we have no pending targets at this point, we can't make progress and should exit.
            if self.pending_targets.is_empty() {
                break;
            }
        }

        Ok(())
    }

    #[allow(unused)]
    pub async fn execute_target(&mut self, target_name: &str) -> Result<()> {
        self.schedule_target(target_name).await?;
        self.run_to_completion().await?;

        Ok(())
    }

    async fn dispatch_target(&mut self, target: Arc<Target>) -> Result<()> {
        self.pending_targets
            .insert(target.full_name(), target.clone());
        self.dispatch_tx
            .send(ExecutionContext {
                target,
                // TODO(frolv): Allow setting platforms dynamically.
                target_platform: Platform::current(),
                // TODO(konkers): Add canonical location for `output_dir` and `work_dir`.
                output_dir: "".into(),
                work_dir: "".into(),
            })
            .await
            .map_err(|e| Error::StringErrorPlaceholder(format!("Error dispatching target: {e}")))
    }
}

/// A task that takes `ExecutionContext`s from the shared dispatch queue,
/// executes them, and reports status back to the executor.
struct Worker {
    /// Channel used to report `ExecutionStatusMsg`s to the executor.
    status_tx: mpsc::UnboundedSender<ExecutionStatusMsg>,
    /// Work queue; every worker holds a clone of the same receiver.
    dispatch_rx: async_channel::Receiver<ExecutionContext>,
}

impl Worker {
    /// Spawns a worker task on the Tokio runtime and returns its join handle.
    fn spawn(
        status_tx: mpsc::UnboundedSender<ExecutionStatusMsg>,
        dispatch_rx: async_channel::Receiver<ExecutionContext>,
    ) -> JoinHandle<()> {
        let worker = Worker {
            status_tx,
            dispatch_rx,
        };
        tokio::spawn(worker.run())
    }

    /// Reports `status` for `target` to the executor.
    ///
    /// # Errors
    /// Fails if the executor's status receiver has been dropped.
    fn send_status(&self, target: &Target, status: ExecutionStatus) -> Result<()> {
        self.status_tx
            .send(ExecutionStatusMsg {
                name: target.full_name(),
                status,
            })
            .map_err(|e| Error::StringErrorPlaceholder(format!("error sending status: {e}")))
    }

    /// Runs a fake target (test builds only).
    #[cfg(test)]
    async fn handle_fake_target(&self, target: &Target, metadata: &Fake) -> Result<()> {
        use crate::fake;

        fake::run(target, metadata, &self.status_tx).await
    }

    /// Runs a download target.
    async fn handle_download_target(
        &self,
        context: &ExecutionContext,
        metadata: &Download,
    ) -> Result<()> {
        let downloader = Downloader::new(context, metadata, &self.status_tx);
        downloader.run().await
    }

    /// Executes one dispatched target by dispatching on its metadata type.
    async fn execute(&self, context: &ExecutionContext) -> Result<()> {
        match context.target.metadata() {
            #[cfg(test)]
            Metadata::Fake(fake) => self.handle_fake_target(&context.target, fake).await,

            Metadata::Download(download) => self.handle_download_target(context, download).await,
            Metadata::Cipd(_) | Metadata::DepOnly => self.send_status(
                &context.target,
                ExecutionStatus::Failed("Unsupported target type".into()),
            ),
        }
    }

    /// Worker main loop: executes targets until the dispatch channel closes.
    ///
    /// A handler error is reported to the executor as
    /// `ExecutionStatus::Failed`, so the executor never waits forever on a
    /// target that will not complete. The worker then keeps serving. It exits
    /// only if that report itself cannot be delivered (the executor is gone).
    async fn run(self) {
        // A closed dispatch channel means the executor is shutting down.
        while let Ok(context) = self.dispatch_rx.recv().await {
            let Err(e) = self.execute(&context).await else {
                continue;
            };

            if self
                .send_status(&context.target, ExecutionStatus::Failed(e.to_string()))
                .is_err()
            {
                // TODO(konkers): Log error.
                println!("worker terminating: {e}");
                break;
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use crate::fake::{Event, FakeCoordinator};

    use super::*;

    // End-to-end check of dependency ordering and worker-pool limits.
    // `dep-test:a` is executed with two workers, and the start/end events
    // recorded by the fake targets are compared against the expected schedule.
    #[tokio::test]
    async fn fake_execution() {
        FakeCoordinator::get().await.reset_ticks();
        let join_handle = tokio::spawn(async move {
            let project = Project::load("./src/test_projects/dependency_test").unwrap();
            let mut executor = Executor::new(&project, 2);
            executor.execute_target("dep-test:a").await.unwrap();
        });

        // Give the executor time to start and dispatch its initial targets.
        tokio::task::yield_now().await;

        // No ticks have been processed yet, so only b and d have started and
        // no other progress has been made.
        assert_eq!(
            FakeCoordinator::get().await.clone_events(),
            vec![Event::start(0, "dep-test:b"), Event::start(0, "dep-test:d"),]
                .into_iter()
                .collect()
        );

        // Tick the fake targets 100 times, which is enough for all of them to
        // complete.
        for _ in 0..100 {
            FakeCoordinator::get().await.increment_ticks(1).await;
        }

        // Wait for the executor to finish.
        join_handle.await.unwrap();

        assert_eq!(
            FakeCoordinator::get().await.clone_events(),
            vec![
                // b, d, and e are all dispatched, but there are only two
                // workers, so only b and d start.
                Event::start(0, "dep-test:b"),
                Event::start(0, "dep-test:d"),
                // Two ticks later b completes, and a tick after that e starts
                // now that a worker is free.
                Event::end(2, "dep-test:b"),
                Event::start(3, "dep-test:e"),
                // Later d and e complete, unblocking c, which starts a tick
                // later.
                Event::end(4, "dep-test:d"),
                Event::end(8, "dep-test:e"),
                Event::start(9, "dep-test:c"),
                // c completes, which unblocks a.
                Event::end(12, "dep-test:c"),
                Event::start(13, "dep-test:a"),
                Event::end(14, "dep-test:a"),
            ]
            .into_iter()
            .collect()
        );
    }
}
