pw_presubmit: Add sticky comments to keep-sorted

Change-Id: Ib7904a3f47fcadc75013fb71bb59b05b0d5e11dd
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/116427
Pigweed-Auto-Submit: Rob Mohr <mohrr@google.com>
Commit-Queue: Auto-Submit <auto-submit@pigweed.google.com.iam.gserviceaccount.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
diff --git a/pw_presubmit/docs.rst b/pw_presubmit/docs.rst
index 2efc87f..b0ef2c7 100644
--- a/pw_presubmit/docs.rst
+++ b/pw_presubmit/docs.rst
@@ -175,15 +175,38 @@
 comma-separated list of prefixes. The list below will be kept in this order.
 Neither commas nor whitespace are supported in prefixes.
 
+.. code-block::
+
   # keep-sorted: start ignore-prefix=',"
   'bar',
   "baz",
   'foo',
   # keep-sorted: end
 
+Comments on their own lines are assumed to be associated with the following
+line. For example, the following is already sorted. This can be disabled with
+``sticky-comments=no``.
+
+.. todo-check: disable
+
+.. code-block::
+
+  # keep-sorted: start
+  # TODO(b/1234) Fix this.
+  bar,
+  # TODO(b/5678) Also fix this.
+  foo,
+  # keep-sorted: end
+
+.. todo-check: enable
+
+By default, the prefix of the keep-sorted line is assumed to be the comment
+marker used by any inline comments. This can be overridden by adding an option
+like ``sticky-comments=%,#`` to the start line.
+
 The presubmit check will suggest fixes using ``pw keep-sorted --fix``.
 
-Future versions may support multiline list items.
+Future versions may support additional kinds of multiline list items.
 
 .gitmodules
 ^^^^^^^^^^^
diff --git a/pw_presubmit/py/keep_sorted_test.py b/pw_presubmit/py/keep_sorted_test.py
index 84550ec..34e6878 100644
--- a/pw_presubmit/py/keep_sorted_test.py
+++ b/pw_presubmit/py/keep_sorted_test.py
@@ -27,6 +27,7 @@
 END = keep_sorted.END
 
 # pylint: disable=attribute-defined-outside-init
+# pylint: disable=too-many-public-methods
 
 
 class TestKeepSorted(unittest.TestCase):
@@ -153,6 +154,97 @@
                   f'a\nB\nfooB\nbarc\n{END}\n')
         self.ctx.fail.assert_not_called()
 
+    def test_python_comment_marks_sorted(self) -> None:
+        self._run(f'# {START}\n1\n2\n# {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_python_comment_marks_not_sorted(self) -> None:
+        self._run(f'# {START}\n2\n1\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents, f'# {START}\n1\n2\n# {END}\n')
+
+    def test_python_comment_sticky_sorted(self) -> None:
+        self._run(f'# {START}\n# A\n1\n2\n# {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_python_comment_sticky_not_sorted(self) -> None:
+        self._run(f'# {START}\n2\n# A\n1\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents, f'# {START}\n# A\n1\n2\n# {END}\n')
+
+    def test_python_comment_sticky_disabled(self) -> None:
+        self._run(f'# {START} sticky-comments=no\n1\n# B\n2\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(
+            self.contents,
+            f'# {START} sticky-comments=no\n# B\n1\n2\n# {END}\n')
+
+    def test_cpp_comment_marks_sorted(self) -> None:
+        self._run(f'// {START}\n1\n2\n// {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_cpp_comment_marks_not_sorted(self) -> None:
+        self._run(f'// {START}\n2\n1\n// {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents, f'// {START}\n1\n2\n// {END}\n')
+
+    def test_cpp_comment_sticky_sorted(self) -> None:
+        self._run(f'// {START}\n1\n// B\n2\n// {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_cpp_comment_sticky_not_sorted(self) -> None:
+        self._run(f'// {START}\n// B\n2\n1\n// {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents, f'// {START}\n1\n// B\n2\n// {END}\n')
+
+    def test_cpp_comment_sticky_disabled(self) -> None:
+        self._run(f'// {START} sticky-comments=no\n1\n// B\n2\n// {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(
+            self.contents,
+            f'// {START} sticky-comments=no\n// B\n1\n2\n// {END}\n')
+
+    def test_custom_comment_sticky_sorted(self) -> None:
+        self._run(f'{START} sticky-comments=%\n1\n% B\n2\n{END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_custom_comment_sticky_not_sorted(self) -> None:
+        self._run(f'{START} sticky-comments=%\n% B\n2\n1\n{END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents,
+                         f'{START} sticky-comments=%\n1\n% B\n2\n{END}\n')
+
+    def test_multiline_comment_sticky_sorted(self) -> None:
+        self._run(f'# {START}\n# B\n# A\n1\n2\n# {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_multiline_comment_sticky_not_sorted(self) -> None:
+        self._run(f'# {START}\n# B\n# A\n2\n1\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents,
+                         f'# {START}\n1\n# B\n# A\n2\n# {END}\n')
+
+    def test_comment_sticky_sorted_fallback_sorted(self) -> None:
+        self._run(f'# {START}\n# A\n1\n# B\n1\n# {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_comment_sticky_sorted_fallback_not_sorted(self) -> None:
+        self._run(f'# {START}\n# B\n1\n# A\n1\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(self.contents,
+                         f'# {START}\n# A\n1\n# B\n1\n# {END}\n')
+
+    def test_comment_sticky_sorted_fallback_dupes(self) -> None:
+        self._run(f'# {START} allow-dupes\n# A\n1\n# A\n1\n# {END}\n')
+        self.ctx.fail.assert_not_called()
+
+    def test_different_comment_sticky_not_sorted(self) -> None:
+        self._run(f'# {START} sticky-comments=%\n% A\n1\n# B\n2\n# {END}\n')
+        self.ctx.fail.assert_called()
+        self.assertEqual(
+            self.contents,
+            f'# {START} sticky-comments=%\n# B\n% A\n1\n2\n# {END}\n')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/pw_presubmit/py/pw_presubmit/keep_sorted.py b/pw_presubmit/py/pw_presubmit/keep_sorted.py
index 57fd24d..93ddccd 100644
--- a/pw_presubmit/py/pw_presubmit/keep_sorted.py
+++ b/pw_presubmit/py/pw_presubmit/keep_sorted.py
@@ -36,6 +36,7 @@
 _IGNORE_CASE = re.compile(r'ignore-case', re.IGNORECASE)
 _ALLOW_DUPES = re.compile(r'allow-dupes', re.IGNORECASE)
 _IGNORE_PREFIX = re.compile(r'ignore-prefix=(\S+)', re.IGNORECASE)
+_STICKY_COMMENTS = re.compile(r'sticky-comments=(\S+)', re.IGNORECASE)
 
 # Only include these literals here so keep_sorted doesn't try to reorder later
 # test lines.
@@ -77,10 +78,28 @@
 
 
 @dataclasses.dataclass
+class _Line:
+    value: str = ''
+    sticky_comments: Sequence[str] = ()
+
+    @property
+    def full(self):
+        return ''.join((*self.sticky_comments, self.value))
+
+    def __lt__(self, other):
+        if not isinstance(other, _Line):
+            return NotImplemented
+        if self.value != other.value:
+            return self.value < other.value
+        return self.sticky_comments < other.sticky_comments
+
+
+@dataclasses.dataclass
 class _Block:
     ignore_case: bool = False
     allow_dupes: bool = False
     ignored_prefixes: Sequence[str] = dataclasses.field(default_factory=list)
+    sticky_comments: Tuple[str, ...] = ()
     start_line_number: int = -1
     start_line: str = ''
     end_line: str = ''
@@ -96,13 +115,34 @@
         self.changed: bool = False
 
     def _process_block(self, block: _Block) -> Sequence[str]:
-        lines_after_dupes: List[str] = []
-        if block.allow_dupes:
-            lines_after_dupes = block.lines
-        else:
-            lines_after_dupes = list({x: None for x in block.lines})
+        raw_lines: List[str] = block.lines
+        lines: List[_Line] = []
 
-        sort_key_funcs: List[Callable[[Tuple[str, ...]], Tuple[str, ...]]] = []
+        if block.sticky_comments:
+            comments: List[str] = []
+            for raw_line in raw_lines:
+                if raw_line.lstrip().startswith(block.sticky_comments):
+                    _LOG.debug('found sticky %s', raw_line.strip())
+                    comments.append(raw_line)
+                else:
+                    _LOG.debug('non-sticky %s', raw_line.strip())
+                    line = _Line(raw_line, tuple(comments))
+                    _LOG.debug('line %s', line)
+                    lines.append(line)
+                    comments = []
+            if comments:
+                self.ctx.fail(
+                    f'sticky comment at end of block: {comments[0].strip()}',
+                    self.path, block.start_line_number)
+
+        else:
+            lines = [_Line(x) for x in block.lines]
+
+        if not block.allow_dupes:
+            lines = list({x.full: x for x in lines}.values())
+
+        StrLinePair = Tuple[str, _Line]
+        sort_key_funcs: List[Callable[[StrLinePair], StrLinePair]] = []
 
         if block.ignored_prefixes:
 
@@ -120,18 +160,22 @@
         if block.ignore_case:
             sort_key_funcs.append(lambda val: (val[0].lower(), val[1]))
 
-        def sort_key(val):
-            vals = (val, val)
+        def sort_key(line):
+            vals = (line.value, line)
             for sort_key_func in sort_key_funcs:
                 vals = sort_key_func(vals)
             return vals
 
-        for val in lines_after_dupes:
+        for val in lines:
             _LOG.debug('For sorting: %r => %r', val, sort_key(val))
 
-        sorted_lines = sorted(lines_after_dupes, key=sort_key)
+        sorted_lines = sorted(lines, key=sort_key)
+        raw_sorted_lines: List[str] = []
+        for line in sorted_lines:
+            raw_sorted_lines.extend(line.sticky_comments)
+            raw_sorted_lines.append(line.value)
 
-        if block.lines != sorted_lines:
+        if block.lines != raw_sorted_lines:
             self.changed = True
             self.ctx.fail('keep-sorted block is not sorted', self.path,
                           block.start_line_number)
@@ -139,7 +183,7 @@
             diff = difflib.Differ()
             for dline in diff.compare(
                 [x.rstrip() for x in block.lines],
-                [x.rstrip() for x in sorted_lines],
+                [x.rstrip() for x in raw_sorted_lines],
             ):
                 if dline.startswith('-'):
                     dline = _COLOR.red(dline)
@@ -148,7 +192,7 @@
                 _LOG.info(dline)
             _LOG.info('  %s', block.end_line.rstrip())
 
-        return sorted_lines
+        return raw_sorted_lines
 
     def _parse_file(self, ins):
         block: Optional[_Block] = None
@@ -191,6 +235,19 @@
                     block.ignored_prefixes.sort(key=lambda x: (-len(x), x))
                 _LOG.debug('ignored_prefixes: %r', block.ignored_prefixes)
 
+                match = _STICKY_COMMENTS.search(line)
+                if match:
+                    if match.group(1) == 'no':
+                        block.sticky_comments = ()
+                    else:
+                        block.sticky_comments = tuple(
+                            match.group(1).split(','))
+                else:
+                    prefix = line[:start_match.start()].strip()
+                    if prefix and len(prefix) <= 3:
+                        block.sticky_comments = (prefix, )
+                _LOG.debug('sticky_comments: %s', block.sticky_comments)
+
                 block.start_line = line
                 block.start_line_number = i
                 self.all_lines.append(line)
@@ -199,6 +256,8 @@
                 remaining = _IGNORE_CASE.sub('', remaining, count=1).strip()
                 remaining = _ALLOW_DUPES.sub('', remaining, count=1).strip()
                 remaining = _IGNORE_PREFIX.sub('', remaining, count=1).strip()
+                remaining = _STICKY_COMMENTS.sub('', remaining,
+                                                 count=1).strip()
                 if remaining.strip():
                     raise KeepSortedParsingError(
                         f'unrecognized directive on keep-sorted line: '
diff --git a/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py b/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
index 00cecf3..a87a16f 100755
--- a/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
+++ b/pw_presubmit/py/pw_presubmit/pigweed_presubmit.py
@@ -879,7 +879,8 @@
 
 OTHER_CHECKS = (
     # keep-sorted: start
-    bazel_test,  # TODO(b/235277910): Enable all Bazel tests when they're fixed.
+    # TODO(b/235277910): Enable all Bazel tests when they're fixed.
+    bazel_test,
     build.gn_gen_check,
     cmake_clang,
     cmake_gcc,
@@ -891,7 +892,8 @@
     gn_full_qemu_check,
     gn_gcc_build,
     npm_presubmit.npm_test,
-    oss_fuzz_build,  # Attempts to duplicate OSS-Fuzz. Currently failing.
+    # Attempts to duplicate OSS-Fuzz. Currently failing.
+    oss_fuzz_build,
     pw_transfer_integration_test,
     static_analysis,
     stm32f429i,