// blob: dca3a14c96b6e07f444871a53bff1066818f038d [file] [log] [blame]
// Copyright 2020 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::{BTreeMap, HashSet};
use std::error::Error;
use std::fmt;
use std::fmt::Write;
use std::iter::Peekable;
use std::mem::take;
/// Errors that can be reported while parsing a command line.
#[derive(Debug, Clone)]
pub(crate) enum FlagParseError {
    /// The argument does not start with `--` or names a flag that was never
    /// defined. Carries the offending argument.
    UnknownFlag(String),
    /// The flag was not followed by any value. Carries the flag name.
    ValueMissing(String),
    /// A single-value flag received more than one value or appeared more
    /// than once. Carries the flag name.
    ProvidedMultipleTimes(String),
    /// The argument vector was empty, so there was no program name.
    ProgramNameMissing,
}

impl fmt::Display for FlagParseError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ProgramNameMissing => f.write_str("program name (argv[0]) missing"),
            Self::UnknownFlag(flag) => write!(f, "unknown flag \"{flag}\""),
            Self::ValueMissing(flag) => write!(f, "flag \"{flag}\" missing parameter(s)"),
            Self::ProvidedMultipleTimes(flag) => {
                write!(f, "flag \"{flag}\" can only appear once")
            }
        }
    }
}

impl Error for FlagParseError {}
/// Definition of one flag: its name, its help text and the caller-owned slot
/// that receives the parsed value.
struct FlagDef<'a, T> {
    /// Flag name exactly as it is matched against the command line.
    name: String,
    /// Help text shown next to the name in the help output.
    help: String,
    /// Destination that is filled in when the flag is parsed.
    output_storage: &'a mut Option<T>,
}

impl<T> fmt::Display for FlagDef<'_, T> {
    /// Renders the flag as `<name>\t<help>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name, help, .. } = self;
        write!(f, "{name}\t{help}")
    }
}

impl<T> fmt::Debug for FlagDef<'_, T> {
    /// Debug output lists only `name` and `help`; `output_storage` is left
    /// out, so no `Debug` bound on `T` is needed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("FlagDef");
        repr.field("name", &self.name);
        repr.field("help", &self.help);
        repr.finish()
    }
}
/// Command-line flag parser.
///
/// Holds the flag definitions registered through `define_flag` and
/// `define_repeated_flag`; the `'a` lifetime ties the parser to the
/// caller-owned storage that parsed values are written into.
#[derive(Debug)]
pub(crate) struct Flags<'a> {
/// Flags that accept exactly one value, keyed by flag name.
single: BTreeMap<String, FlagDef<'a, String>>,
/// Flags that accumulate any number of values, keyed by flag name.
repeated: BTreeMap<String, FlagDef<'a, Vec<String>>>,
}
/// Result of a successful call to `Flags::parse`.
#[derive(Debug)]
pub(crate) enum ParseOutcome {
/// `--help` was encountered; carries the generated help text.
Help(String),
/// Parsing finished; carries the arguments that followed `--`
/// (empty when no `--` separator was present).
Parsed(Vec<String>),
}
impl<'a> Flags<'a> {
    /// Creates a parser with no flags defined.
    pub(crate) fn new() -> Flags<'a> {
        Self {
            single: BTreeMap::new(),
            repeated: BTreeMap::new(),
        }
    }

    /// Defines a flag that takes exactly one value.
    ///
    /// `name` is matched verbatim against the command line; since `parse`
    /// rejects arguments that do not start with `--`, it should include the
    /// leading dashes. On a successful match the value is written to
    /// `output_storage`. Defining the same single flag again replaces the
    /// earlier definition.
    ///
    /// # Panics
    ///
    /// Panics if `name` is already defined as a repeated flag.
    pub(crate) fn define_flag(
        &mut self,
        name: impl Into<String>,
        help: impl Into<String>,
        output_storage: &'a mut Option<String>,
    ) {
        let name = name.into();
        assert!(
            !self.repeated.contains_key(&name),
            "argument \"{}\" already defined as repeated flag",
            name
        );
        let definition = FlagDef {
            name: name.clone(),
            help: help.into(),
            output_storage,
        };
        self.single.insert(name, definition);
    }

    /// Defines a flag that may take several values and may appear several
    /// times.
    ///
    /// All values are appended, in command-line order, to the vector in
    /// `output_storage` (created on first use). Defining the same repeated
    /// flag again replaces the earlier definition.
    ///
    /// # Panics
    ///
    /// Panics if `name` is already defined as a single-value flag.
    pub(crate) fn define_repeated_flag(
        &mut self,
        name: impl Into<String>,
        help: impl Into<String>,
        output_storage: &'a mut Option<Vec<String>>,
    ) {
        let name = name.into();
        assert!(
            !self.single.contains_key(&name),
            "argument \"{}\" already defined as flag",
            name
        );
        let definition = FlagDef {
            name: name.clone(),
            help: help.into(),
            output_storage,
        };
        self.repeated.insert(name, definition);
    }

    /// Builds the help text: a header naming the program followed by one
    /// tab-indented line per defined flag, sorted lexicographically.
    fn help(&self, program_name: String) -> String {
        let mut entries: Vec<String> = self
            .single
            .values()
            .map(ToString::to_string)
            .chain(self.repeated.values().map(ToString::to_string))
            .collect();
        entries.sort();

        let mut text = String::new();
        writeln!(
            text,
            "Help for {program_name}: [options] -- [extra arguments]"
        )
        .unwrap();
        for entry in &entries {
            writeln!(text, "\t{entry}").unwrap();
        }
        text
    }

    /// Parses `argv`, storing flag values in the storage registered for each
    /// flag.
    ///
    /// The first element of `argv` is taken as the program name. Every flag
    /// consumes the arguments that follow it up to the next argument starting
    /// with `--`. Parsing stops at `--help` (returning
    /// [`ParseOutcome::Help`]) or at `--` (returning the remaining arguments
    /// in [`ParseOutcome::Parsed`]).
    ///
    /// # Errors
    ///
    /// * [`FlagParseError::ProgramNameMissing`] if `argv` is empty.
    /// * [`FlagParseError::UnknownFlag`] if an argument does not start with
    ///   `--` or names an undefined flag.
    /// * [`FlagParseError::ValueMissing`] if a flag is not followed by a value.
    /// * [`FlagParseError::ProvidedMultipleTimes`] if a single-value flag
    ///   receives more than one value or appears more than once.
    pub(crate) fn parse(mut self, argv: Vec<String>) -> Result<ParseOutcome, FlagParseError> {
        let mut remaining = argv.into_iter().peekable();
        let program_name = match remaining.next() {
            Some(name) => name,
            None => return Err(FlagParseError::ProgramNameMissing),
        };

        // Names of single-value flags already assigned, to detect repeats.
        let mut seen_single_flags: HashSet<String> = HashSet::new();

        while let Some(flag) = remaining.next() {
            if flag == "--help" {
                return Ok(ParseOutcome::Help(self.help(program_name)));
            }
            if !flag.starts_with("--") {
                return Err(FlagParseError::UnknownFlag(flag));
            }

            let mut values = consume_args(&flag, &mut remaining);
            if flag == "--" {
                // Everything after the separator is handed back untouched.
                return Ok(ParseOutcome::Parsed(values));
            }
            if values.is_empty() {
                return Err(FlagParseError::ValueMissing(flag));
            }

            if let Some(definition) = self.single.get_mut(&flag) {
                if values.len() > 1 || seen_single_flags.contains(&flag) {
                    return Err(FlagParseError::ProvidedMultipleTimes(flag));
                }
                seen_single_flags.insert(flag);
                // Exactly one value is present here; move it out of the vector.
                *definition.output_storage = values.first_mut().map(take);
            } else if let Some(definition) = self.repeated.get_mut(&flag) {
                definition
                    .output_storage
                    .get_or_insert_with(Vec::new)
                    .extend(values);
            } else {
                return Err(FlagParseError::UnknownFlag(flag));
            }
        }

        Ok(ParseOutcome::Parsed(Vec::new()))
    }
}
/// Collects the values that belong to `flag` from `argv_iter`.
///
/// For the `--` separator the whole remainder of the iterator is returned
/// as-is. For any other flag, arguments are taken until the iterator ends or
/// the next argument starts with `--`; that argument is left in the iterator.
fn consume_args<I: Iterator<Item = String>>(
    flag: &str,
    argv_iter: &mut Peekable<I>,
) -> Vec<String> {
    if flag == "--" {
        return argv_iter.collect();
    }
    std::iter::from_fn(|| argv_iter.next_if(|candidate| !candidate.starts_with("--"))).collect()
}
#[cfg(test)]
mod test {
    use super::*;

    /// Builds an argument vector with "foo" as the program name followed by
    /// `args`.
    fn args(args: &[&str]) -> Vec<String> {
        let mut argv = vec!["foo".to_owned()];
        argv.extend(args.iter().map(|s| s.to_string()));
        argv
    }

    #[test]
    fn test_flag_help() {
        let mut bar = None;
        let mut parser = Flags::new();
        parser.define_flag("--bar", "bar help", &mut bar);

        match parser.parse(args(&["--help"])).unwrap() {
            ParseOutcome::Help(text) => {
                assert!(text.contains("Help for foo"));
                assert!(text.contains("--bar\tbar help"));
            }
            ParseOutcome::Parsed(_) => {
                panic!("expected that --help would invoke help, instead parsed arguments")
            }
        }
    }

    #[test]
    fn test_flag_single_repeated() {
        // A single-value flag must be rejected both when it receives two
        // values at once and when it is given twice.
        let mut bar = None;
        let command_lines = [
            args(&["--bar", "aa", "bb"]),
            args(&["--bar", "aa", "--bar", "bb"]),
        ];
        for argv in command_lines {
            let mut parser = Flags::new();
            parser.define_flag("--bar", "bar help", &mut bar);
            match parser.parse(argv) {
                Err(FlagParseError::ProvidedMultipleTimes(f)) => assert_eq!(f, "--bar"),
                other => panic!("expected error, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_repeated_flags() {
        // A repeated flag accepts both `--bar something something_else` and
        // `--bar something --bar something_else`, with the same result.
        let command_lines = [
            args(&["--bar", "aa", "bb"]),
            args(&["--bar", "aa", "--bar", "bb"]),
        ];
        for argv in command_lines {
            let mut bar = None;
            let mut parser = Flags::new();
            parser.define_repeated_flag("--bar", "bar help", &mut bar);
            let outcome = parser.parse(argv).unwrap();
            assert!(matches!(outcome, ParseOutcome::Parsed(_)));
            assert_eq!(bar, Some(vec!["aa".to_owned(), "bb".to_owned()]));
        }
    }

    #[test]
    fn test_extra_args() {
        let outcome = Flags::new().parse(args(&["--", "bb"])).unwrap();
        match outcome {
            ParseOutcome::Parsed(extra) => assert_eq!(extra, vec!["bb".to_owned()]),
            other => panic!("expected correct parsing, got {:?}", other),
        }
    }
}