// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashSet};
use std::error::Error;
use std::fmt;
use std::fmt::Write;
use std::iter::Peekable;
use std::mem::take;

/// Errors that can be reported while parsing command-line flags.
///
/// Variants carrying a `String` hold the flag text exactly as it appeared on
/// the command line.
#[derive(Debug, Clone)]
pub(crate) enum FlagParseError {
    /// The argument is not a flag known to the parser.
    UnknownFlag(String),
    /// The flag was not followed by any value.
    ValueMissing(String),
    /// A flag that accepts a single value received more than one.
    ProvidedMultipleTimes(String),
    /// The argument list was empty, so not even the program name was present.
    ProgramNameMissing,
}

impl fmt::Display for FlagParseError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // The only variant without a payload: emit the constant message directly.
            Self::ProgramNameMissing => f.write_str("program name (argv[0]) missing"),
            Self::ProvidedMultipleTimes(flag) => {
                write!(f, "flag \"{flag}\" can only appear once")
            }
            Self::ValueMissing(flag) => write!(f, "flag \"{flag}\" missing parameter(s)"),
            Self::UnknownFlag(flag) => write!(f, "unknown flag \"{flag}\""),
        }
    }
}

// No source error to expose: the default `Error` methods are sufficient.
impl Error for FlagParseError {}

/// Definition of one flag: its name, its help text, and the location that
/// receives the parsed value.
struct FlagDef<'a, T> {
    /// Flag name as written on the command line.
    name: String,
    /// Description shown next to the name in the help listing.
    help: String,
    /// Caller-owned slot that the parsed value is written into.
    output_storage: &'a mut Option<T>,
}

impl<T> fmt::Display for FlagDef<'_, T> {
    /// Formats the definition as `<name>\t<help>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { name, help, .. } = self;
        write!(f, "{}\t{}", name, help)
    }
}

impl<T> fmt::Debug for FlagDef<'_, T> {
    /// Debug output lists `name` and `help` only. It is written by hand so
    /// that `T` does not need to implement `Debug`; `output_storage` is
    /// left out.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("FlagDef");
        builder.field("name", &self.name);
        builder.field("help", &self.help);
        builder.finish()
    }
}

/// A set of flag definitions that can be parsed from an argument vector.
///
/// Each definition holds a mutable borrow (lifetime `'a`) of the
/// caller-owned variable that receives the flag's parsed value.
#[derive(Debug)]
pub(crate) struct Flags<'a> {
    // Flags whose value is stored as a single `String`, keyed by flag name.
    single: BTreeMap<String, FlagDef<'a, String>>,
    // Flags whose values are accumulated in a `Vec<String>`, keyed by flag name.
    repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>,
}

/// Result of a parse that did not end in an error.
#[derive(Debug)]
pub(crate) enum ParseOutcome {
    /// `--help` was encountered; carries the generated help text.
    Help(String),
    /// Parsing finished; carries the arguments that followed `--`
    /// (empty when `--` was not present).
    Parsed(Vec<String>),
}

impl<'a> Flags<'a> {
    /// Creates a parser with no flags defined.
    pub(crate) fn new() -> Flags<'a> {
        Flags {
            single: BTreeMap::new(),
            repeated: BTreeMap::new(),
        }
    }

    /// Defines a flag that takes exactly one value and may appear only once.
    ///
    /// When `parse` accepts the flag, `output_storage` is set to
    /// `Some(value)`; it is not touched if the flag never appears. Defining
    /// `name` again as a single-value flag replaces the earlier definition.
    ///
    /// # Panics
    ///
    /// Panics if `name` is already defined as a repeated flag.
    pub(crate) fn define_flag(
        &mut self,
        name: impl Into<String>,
        help: impl Into<String>,
        output_storage: &'a mut Option<String>,
    ) {
        let name = name.into();
        if self.repeated.contains_key(&name) {
            panic!("argument \"{}\" already defined as repeated flag", name)
        }
        self.single.insert(
            name.clone(),
            FlagDef::<'a, String> {
                name,
                help: help.into(),
                output_storage,
            },
        );
    }

    /// Defines a flag that takes one or more values and may appear any number
    /// of times.
    ///
    /// Every value given to the flag is appended, in command-line order, to
    /// the vector in `output_storage` (created on first use). The storage is
    /// not touched if the flag never appears. Defining `name` again as a
    /// repeated flag replaces the earlier definition.
    ///
    /// # Panics
    ///
    /// Panics if `name` is already defined as a single-value flag.
    pub(crate) fn define_repeated_flag(
        &mut self,
        name: impl Into<String>,
        help: impl Into<String>,
        output_storage: &'a mut Option<Vec<String>>,
    ) {
        let name = name.into();
        if self.single.contains_key(&name) {
            panic!("argument \"{}\" already defined as flag", name)
        }
        self.repeated.insert(
            name.clone(),
            FlagDef::<'a, Vec<String>> {
                name,
                help: help.into(),
                output_storage,
            },
        );
    }

    /// Renders the help text: a usage line naming `program_name`, followed by
    /// one tab-indented `<name>\t<help>` line per defined flag, sorted.
    fn help(&self, program_name: String) -> String {
        let single = self.single.values().map(|fd| fd.to_string());
        let repeated = self.repeated.values().map(|fd| fd.to_string());
        // Each map is ordered on its own; sort again so that single and
        // repeated flags are interleaved in one ordered listing.
        let mut all: Vec<String> = single.chain(repeated).collect();
        all.sort();

        let mut help_text = String::new();
        // Formatting into a `String` cannot fail, hence the `unwrap`s.
        writeln!(
            &mut help_text,
            "Help for {program_name}: [options] -- [extra arguments]"
        )
        .unwrap();
        for line in all {
            writeln!(&mut help_text, "\t{line}").unwrap();
        }
        help_text
    }

    /// Parses `argv` and writes flag values into the storage registered for
    /// each flag.
    ///
    /// `argv[0]` is taken as the program name. Every argument after it that is
    /// in flag position must start with `--`; the arguments that follow a
    /// flag, up to the next argument starting with `--`, are that flag's
    /// values. Two flags are built in:
    ///
    /// * `--help` stops parsing and returns [`ParseOutcome::Help`] with the
    ///   help text.
    /// * `--` stops parsing and returns [`ParseOutcome::Parsed`] with all
    ///   remaining arguments, untouched.
    ///
    /// If the arguments run out without `--`, an empty
    /// [`ParseOutcome::Parsed`] is returned.
    ///
    /// # Errors
    ///
    /// * [`FlagParseError::ProgramNameMissing`] if `argv` is empty.
    /// * [`FlagParseError::UnknownFlag`] if an argument in flag position does
    ///   not start with `--` or names a flag that was never defined.
    /// * [`FlagParseError::ValueMissing`] if a defined flag has no value.
    /// * [`FlagParseError::ProvidedMultipleTimes`] if a single-value flag is
    ///   given more than one value or appears more than once.
    ///
    /// Storage for flags accepted before the error may already have been
    /// written.
    pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> {
        let mut argv_iter = argv.into_iter().peekable();
        let program_name = argv_iter.next().ok_or(FlagParseError::ProgramNameMissing)?;

        // To check if a non-repeated flag has been set already. The storage
        // itself cannot be used for this: the caller may have pre-filled it.
        let mut seen_single_flags = HashSet::<String>::new();

        while let Some(flag) = argv_iter.next() {
            if flag == "--help" {
                return Ok(ParseOutcome::Help(self.help(program_name)));
            }
            if !flag.starts_with("--") {
                return Err(FlagParseError::UnknownFlag(flag));
            }
            let mut args = consume_args(&flag, &mut argv_iter);
            if flag == "--" {
                return Ok(ParseOutcome::Parsed(args));
            }
            // Look the flag up before validating its values, so that an
            // undefined flag is always reported as unknown rather than as a
            // known flag with a missing value.
            if let Some(flag_def) = self.single.get_mut(&flag) {
                if args.is_empty() {
                    return Err(FlagParseError::ValueMissing(flag));
                }
                if args.len() > 1 || seen_single_flags.contains(&flag) {
                    return Err(FlagParseError::ProvidedMultipleTimes(flag));
                }
                // Exactly one value is present at this point.
                let arg = args.first_mut().unwrap();
                seen_single_flags.insert(flag);
                *flag_def.output_storage = Some(take(arg));
                continue;
            }
            if let Some(flag_def) = self.repeated.get_mut(&flag) {
                if args.is_empty() {
                    return Err(FlagParseError::ValueMissing(flag));
                }
                flag_def
                    .output_storage
                    .get_or_insert_with(Vec::new)
                    .append(&mut args);
                continue;
            }
            return Err(FlagParseError::UnknownFlag(flag));
        }
        Ok(ParseOutcome::Parsed(vec![]))
    }
}

/// Collects the values that belong to `flag` from `argv_iter`.
///
/// For the `--` separator, every remaining argument is returned unchanged.
/// For any other flag, arguments are taken until the next one that starts
/// with `--`; that argument is left in the iterator for the caller.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    match flag {
        // After `--` nothing is interpreted: drain the iterator as-is.
        "--" => argv_iter.collect(),
        // `next_if` only advances when the predicate holds, so the next flag
        // is peeked at but not consumed.
        _ => std::iter::from_fn(|| argv_iter.next_if(|candidate| !candidate.starts_with("--")))
            .collect(),
    }
}

// Unit tests for flag definition and parsing.
#[cfg(test)]
mod test {
    use super::*;

    // Builds an argument vector with "foo" as the program name (argv[0])
    // followed by the given arguments.
    fn args(args: &[&str]) -> Vec<String> {
        ["foo"].iter().chain(args).map(|&s| s.to_owned()).collect()
    }

    // `--help` must yield the help outcome, mentioning the program name and
    // every defined flag with its help text.
    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--help"])).unwrap();
        if let ParseOutcome::Help(h) = result {
            assert!(h.contains("Help for foo"));
            assert!(h.contains("--bar\tbar help"));
        } else {
            panic!("expected that --help would invoke help, instead parsed arguments")
        }
    }

    // A single-value flag must be rejected both when it is followed by two
    // values and when the flag itself appears twice.
    #[test]
    fn test_flag_single_repeated() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
        // `parse` consumed the first parser, so a fresh one is needed.
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"]));
        if let Err(FlagParseError::ProvidedMultipleTimes(f)) = result {
            assert_eq!(f, "--bar");
        } else {
            panic!("expected error, got {:?}", result)
        }
    }

    #[test]
    fn test_repeated_flags() {
        // Test case 1) --bar something something_else should work as a repeated flag.
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        // Test case 2) --bar something --bar something_else should also work as a repeated flag.
        // Reset the storage so values from case 1 do not leak into this case.
        bar = None;
        let mut parser = Flags::new();
        parser.define_repeated_flag("--bar", "bar help", &mut bar);
        let result = parser.parse(args(&["--bar", "aa", "--bar", "bb"])).unwrap();
        assert!(matches!(result, ParseOutcome::Parsed(_)));
        assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
    }

    // Everything after `--` must be returned verbatim as extra arguments,
    // even when no flags are defined.
    #[test]
    fn test_extra_args() {
        let parser = Flags::new();
        let result = parser.parse(args(&["--", "bb"])).unwrap();
        if let ParseOutcome::Parsed(got) = result {
            assert_eq!(got, vec!["bb".to_owned()])
        } else {
            panic!("expected correct parsing, got {:?}", result)
        }
    }
}
