Paul Wankadia | af137a1 | 2023-05-15 18:08:54 +0000 | [diff] [blame] | 1 | # Copyright 2019 The RE2 Authors. All Rights Reserved. |
| 2 | # Use of this source code is governed by a BSD-style |
| 3 | # license that can be found in the LICENSE file. |
| 4 | r"""A drop-in replacement for the re module. |
| 5 | |
| 6 | It uses RE2 under the hood, of course, so various PCRE features |
| 7 | (e.g. backreferences, look-around assertions) are not supported. |
| 8 | See https://github.com/google/re2/wiki/Syntax for the canonical |
| 9 | reference, but known syntactic "gotchas" relative to Python are: |
| 10 | |
| 11 | * PCRE supports \Z and \z; RE2 supports \z; Python supports \z, |
| 12 | but calls it \Z. You must rewrite \Z to \z in pattern strings. |
| 13 | |
| 14 | Known differences between this module's API and the re module's API: |
| 15 | |
| 16 | * The error class does not provide any error information as attributes. |
| 17 | * The Options class replaces the re module's flags with RE2's options as |
| 18 | gettable/settable properties. Please see re2.h for their documentation. |
| 19 | * The pattern string and the input string do not have to be the same type. |
| 20 | Any str will be encoded to UTF-8. |
| 21 | * The pattern string cannot be str if the options specify Latin-1 encoding. |
| 22 | |
| 23 | This module's LRU cache contains a maximum of 128 regular expression objects. |
| 24 | Each regular expression object's underlying RE2 object uses a maximum of 8MiB |
| 25 | of memory (by default). Hence, this module's LRU cache uses a maximum of 1GiB |
| 26 | of memory (by default), but in most cases, it should use much less than that. |
| 27 | """ |
| 28 | |
| 29 | import codecs |
| 30 | import functools |
| 31 | import itertools |
| 32 | |
| 33 | import _re2 |
| 34 | |
| 35 | |
class error(Exception):
  """Raised when RE2 rejects a pattern or an operation fails.

  Unlike re.error, no details are exposed as attributes; the message passed
  to the constructor is the only information available.
  """
| 38 | |
| 39 | |
class Options(_re2.RE2.Options):
  """RE2 options, exposed as gettable/settable properties (see re2.h).

  NAMES enumerates every option that is read when compiling or pickling a
  regular expression. Its order is significant, because option values are
  carried around as tuples laid out in exactly this order.
  """

  __slots__ = ()

  NAMES = ('max_mem', 'encoding', 'posix_syntax', 'longest_match',
           'log_errors', 'literal', 'never_nl', 'dot_nl', 'never_capture',
           'case_sensitive', 'perl_classes', 'word_boundary', 'one_line')
| 59 | |
| 60 | |
def compile(pattern, options=None):
  """Compiles pattern (str, bytes or an existing _Regexp) into a _Regexp.

  The result comes from _Regexp._make(), which memoises on the pattern plus
  the value of every option listed in Options.NAMES.

  Raises:
    error: if pattern is already compiled and options were also given, or
      if the pattern itself cannot be compiled.
  """
  if isinstance(pattern, _Regexp):
    if options:
      raise error('pattern is already compiled, so '
                  'options may not be specified')
    pattern = pattern._pattern
  if not options:
    options = Options()
  values = tuple([getattr(options, name) for name in Options.NAMES])
  return _Regexp._make(pattern, values)
| 70 | |
| 71 | |
def search(pattern, text, options=None):
  """Returns the first match of pattern anywhere in text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.search(text)
| 74 | |
| 75 | |
def match(pattern, text, options=None):
  """Returns the match of pattern anchored at the start of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.match(text)
| 78 | |
| 79 | |
def fullmatch(pattern, text, options=None):
  """Returns the match of pattern against the whole of text, or None."""
  regexp = compile(pattern, options=options)
  return regexp.fullmatch(text)
| 82 | |
| 83 | |
def finditer(pattern, text, options=None):
  """Returns an iterator over the successive matches of pattern in text."""
  regexp = compile(pattern, options=options)
  return regexp.finditer(text)
| 86 | |
| 87 | |
def findall(pattern, text, options=None):
  """Returns a list of the successive matches (or groups) of pattern."""
  regexp = compile(pattern, options=options)
  return regexp.findall(text)
| 90 | |
| 91 | |
def split(pattern, text, maxsplit=0, options=None):
  """Splits text on occurrences of pattern.

  maxsplit=0 means no limit; a negative maxsplit performs no splits.
  """
  regexp = compile(pattern, options=options)
  return regexp.split(text, maxsplit)
| 94 | |
| 95 | |
def subn(pattern, repl, text, count=0, options=None):
  """Returns (new_text, number_of_substitutions_made).

  repl is either a template string or a callable taking a match; count=0
  means no limit.
  """
  regexp = compile(pattern, options=options)
  return regexp.subn(repl, text, count)
| 98 | |
| 99 | |
def sub(pattern, repl, text, count=0, options=None):
  """Returns text with matches of pattern replaced by repl (see subn())."""
  regexp = compile(pattern, options=options)
  return regexp.sub(repl, text, count)
| 102 | |
| 103 | |
| 104 | def _encode(t): |
| 105 | return t.encode(encoding='utf-8') |
| 106 | |
| 107 | |
| 108 | def _decode(b): |
| 109 | return b.decode(encoding='utf-8') |
| 110 | |
| 111 | |
def escape(pattern):
  """Escapes pattern using RE2::QuoteMeta().

  A str pattern is round-tripped through UTF-8 so that the result is also
  str. Any other pattern is handed to QuoteMeta() unchanged.
  """
  if not isinstance(pattern, str):
    return _re2.RE2.QuoteMeta(pattern)
  return _decode(_re2.RE2.QuoteMeta(_encode(pattern)))
| 121 | |
| 122 | |
def purge():
  """Empties the LRU cache of compiled regular expressions used by compile()."""
  cache_clear = _Regexp._make.cache_clear
  return cache_clear()
| 125 | |
| 126 | |
# Shorthand for RE2's anchoring enum (UNANCHORED, ANCHOR_START, ANCHOR_BOTH).
_Anchor = _re2.RE2.Anchor

# The span RE2 reports for "no match" or for a group that did not take part
# in the match.
_NULL_SPAN = (-1, -1)
| 129 | |
| 130 | |
class _Regexp(object):
  """A compiled regular expression, mirroring the re module's Pattern API.

  Wraps an _re2.RE2 object together with the original pattern (str or
  bytes). Instances are normally created via compile(), which caches them.
  """

  __slots__ = ('_pattern', '_regexp')

  @classmethod
  @functools.lru_cache(typed=True)
  def _make(cls, pattern, values):
    """Builds a _Regexp from pattern and a tuple of option values.

    values is ordered as Options.NAMES. Because it is a tuple, it is
    hashable, so (pattern, values) can serve as the LRU cache key.
    """
    options = Options()
    for name, value in zip(Options.NAMES, values):
      setattr(options, name, value)
    return cls(pattern, options)

  def __init__(self, pattern, options):
    """Compiles pattern with options.

    Raises:
      error: if pattern is str but options specify LATIN1 encoding, or if
        RE2 fails to compile the pattern.
    """
    self._pattern = pattern
    if isinstance(self._pattern, str):
      if options.encoding == Options.Encoding.LATIN1:
        raise error('string type of pattern is str, but '
                    'encoding specified in options is LATIN1')
      encoded_pattern = _encode(self._pattern)
      self._regexp = _re2.RE2(encoded_pattern, options)
    else:
      self._regexp = _re2.RE2(self._pattern, options)
    if not self._regexp.ok():
      raise error(self._regexp.error())

  def __getstate__(self):
    # Options are pickled as a plain dict keyed by option name.
    options = {name: getattr(self.options, name) for name in Options.NAMES}
    return self._pattern, options

  def __setstate__(self, state):
    # Goes through _make() so that unpickling reuses (and populates) the
    # same cache that compile() uses.
    pattern, options = state
    values = tuple(options[name] for name in Options.NAMES)
    other = _Regexp._make(pattern, values)
    self._pattern = other._pattern
    self._regexp = other._regexp

  def _match(self, anchor, text, pos=None, endpos=None):
    """Yields successive non-overlapping matches of this regexp in text.

    Args:
      anchor: an _Anchor value controlling how each RE2 match is anchored.
      text: the str or bytes to search.
      pos: index at which to start searching; clamped to [0, len(text)].
      endpos: index at which to stop searching; clamped to [0, len(text)].

    Yields:
      _Match objects. Their spans are str indices for str text and bytes
      indices for bytes text, and their .pos/.endpos are the clamped
      arguments.
    """
    pos = 0 if pos is None else max(0, min(pos, len(text)))
    endpos = len(text) if endpos is None else max(0, min(endpos, len(text)))
    if pos > endpos:
      return
    if isinstance(text, str):
      encoded_text = _encode(text)
      encoded_pos = _re2.CharLenToBytes(encoded_text, 0, pos)
      if endpos == len(text):
        # This is the common case.
        encoded_endpos = len(encoded_text)
      else:
        encoded_endpos = encoded_pos + _re2.CharLenToBytes(
            encoded_text, encoded_pos, endpos - pos)
      decoded_offsets = {0: 0}
      last_offset = 0
      while True:
        spans = self._regexp.Match(anchor, encoded_text, encoded_pos,
                                   encoded_endpos)
        if spans[0] == _NULL_SPAN:
          break

        # This algorithm is linear in the length of encoded_text. Specifically,
        # no matter how many groups there are for a given regular expression or
        # how many iterations through the loop there are for a given generator,
        # this algorithm uses a single, straightforward pass over encoded_text.
        offsets = sorted(set(itertools.chain(*spans)))
        if offsets[0] == -1:
          offsets = offsets[1:]
        # Discard the rest of the items because they are useless now - and we
        # could accumulate one item per str offset in the pathological case!
        decoded_offsets = {last_offset: decoded_offsets[last_offset]}
        for offset in offsets:
          decoded_offsets[offset] = (
              decoded_offsets[last_offset] +
              _re2.BytesToCharLen(encoded_text, last_offset, offset))
          last_offset = offset

        def decode(span):
          if span == _NULL_SPAN:
            return span
          return decoded_offsets[span[0]], decoded_offsets[span[1]]

        decoded_spans = [decode(span) for span in spans]
        yield _Match(self, text, pos, endpos, decoded_spans)

        match_start, match_end = spans[0]
        if match_start != match_end:
          # Resume where the non-empty match ended. An empty match may still
          # be found right there (e.g. 'a*' on 'aa' yields 'aa' and then '').
          encoded_pos = match_end
        elif match_end == encoded_endpos:
          # An empty match at endpos is necessarily the last one.
          break
        else:
          # Step one character past the empty match. Resuming at match_end
          # itself would find - and yield - the very same empty match again,
          # which used to happen whenever it lay beyond the search position
          # (e.g. '$' or '\b').
          encoded_pos = match_end + _re2.CharLenToBytes(
              encoded_text, match_end, 1)
    else:
      # Use a separate cursor so that every _Match reports the caller's pos,
      # as the str branch does, rather than the advancing search position.
      search_pos = pos
      while True:
        spans = self._regexp.Match(anchor, text, search_pos, endpos)
        if spans[0] == _NULL_SPAN:
          break
        yield _Match(self, text, pos, endpos, spans)

        match_start, match_end = spans[0]
        if match_start != match_end:
          search_pos = match_end
        elif match_end == endpos:
          break
        else:
          # Step one byte past the empty match; see the str branch above.
          search_pos = match_end + 1

  def search(self, text, pos=None, endpos=None):
    """Returns the first match anywhere in text[pos:endpos], or None."""
    return next(self._match(_Anchor.UNANCHORED, text, pos, endpos), None)

  def match(self, text, pos=None, endpos=None):
    """Returns the match anchored at pos, or None."""
    return next(self._match(_Anchor.ANCHOR_START, text, pos, endpos), None)

  def fullmatch(self, text, pos=None, endpos=None):
    """Returns the match spanning exactly text[pos:endpos], or None."""
    return next(self._match(_Anchor.ANCHOR_BOTH, text, pos, endpos), None)

  def finditer(self, text, pos=None, endpos=None):
    """Returns an iterator over successive non-overlapping matches."""
    return self._match(_Anchor.UNANCHORED, text, pos, endpos)

  def findall(self, text, pos=None, endpos=None):
    """Returns a list of successive non-overlapping matches.

    Each item is the whole match if there are no groups, the single group if
    there is one, or a tuple of all groups otherwise. Groups that did not
    participate become an empty value of the same type as text.
    """
    empty = type(text)()
    groups = self.groups
    items = []
    for match in self.finditer(text, pos, endpos):
      if not groups:
        item = match.group()
      elif groups == 1:
        item = match.groups(default=empty)[0]
      else:
        item = match.groups(default=empty)
      items.append(item)
    return items

  def _split(self, cb, text, maxsplit=0):
    """Splits text around matches, inserting cb(match) between the pieces.

    maxsplit=0 means no limit, and a negative maxsplit performs no splits.

    Returns:
      A (pieces, number_of_matches_used) tuple.
    """
    if maxsplit < 0:
      return [text], 0
    elif maxsplit > 0:
      matchiter = itertools.islice(self.finditer(text), maxsplit)
    else:
      matchiter = self.finditer(text)
    pieces = []
    end = 0
    numsplit = 0
    for match in matchiter:
      pieces.append(text[end:match.start()])
      pieces.extend(cb(match))
      end = match.end()
      numsplit += 1
    pieces.append(text[end:])
    return pieces, numsplit

  def split(self, text, maxsplit=0):
    """Splits text around matches; captured groups are kept in the result."""
    cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
    pieces, _ = self._split(cb, text, maxsplit)
    return pieces

  def subn(self, repl, text, count=0):
    """Returns (new_text, number_of_substitutions_made).

    repl is either a callable taking a match or a template expanded with
    _Match.expand(). count=0 means no limit.
    """
    cb = lambda match: [repl(match) if callable(repl) else match.expand(repl)]
    empty = type(text)()
    pieces, numsplit = self._split(cb, text, count)
    joined_pieces = empty.join(pieces)
    return joined_pieces, numsplit

  def sub(self, repl, text, count=0):
    """Returns text with matches replaced by repl (see subn())."""
    joined_pieces, _ = self.subn(repl, text, count)
    return joined_pieces

  @property
  def pattern(self):
    """The pattern as originally given (str or bytes)."""
    return self._pattern

  @property
  def options(self):
    """The options of the underlying RE2 object."""
    return self._regexp.options()

  @property
  def groups(self):
    """The number of capturing groups."""
    return self._regexp.NumberOfCapturingGroups()

  @property
  def groupindex(self):
    """A dict mapping group names to group indices.

    Names are decoded to str when the pattern is str.
    """
    groups = self._regexp.NamedCapturingGroups()
    if isinstance(self._pattern, str):
      decoded_groups = [(_decode(group), index) for group, index in groups]
      return dict(decoded_groups)
    else:
      return dict(groups)

  @property
  def programsize(self):
    """The result of RE2's ProgramSize()."""
    return self._regexp.ProgramSize()

  @property
  def reverseprogramsize(self):
    """The result of RE2's ReverseProgramSize()."""
    return self._regexp.ReverseProgramSize()

  @property
  def programfanout(self):
    """The result of RE2's ProgramFanout()."""
    return self._regexp.ProgramFanout()

  @property
  def reverseprogramfanout(self):
    """The result of RE2's ReverseProgramFanout()."""
    return self._regexp.ReverseProgramFanout()

  def possiblematchrange(self, maxlen):
    """Returns (min, max) from RE2's PossibleMatchRange(maxlen).

    Raises:
      error: if RE2 cannot compute the range.
    """
    ok, min_match, max_match = self._regexp.PossibleMatchRange(maxlen)
    if not ok:
      raise error('failed to compute match range')
    return min_match, max_match
| 336 | |
| 337 | |
class _Match(object):
  """The result of a successful match, mirroring the re module's Match API.

  Spans are indices into the original text: str indices for str text and
  bytes indices for bytes text. A group that did not participate in the
  match has the span (-1, -1). Groups may be referred to by index or by name
  everywhere a group is accepted.
  """

  __slots__ = ('_regexp', '_text', '_pos', '_endpos', '_spans')

  def __init__(self, regexp, text, pos, endpos, spans):
    self._regexp = regexp
    self._text = text
    self._pos = pos
    self._endpos = endpos
    self._spans = spans

  # Python prioritises three-digit octal numbers over group escapes.
  # For example, \100 should not be handled the same way as \g<10>0.
  _OCTAL_RE = compile('\\\\[0-7][0-7][0-7]')

  # Python supports \1 through \99 (inclusive) and \g<...> syntax.
  _GROUP_RE = compile('\\\\[1-9][0-9]?|\\\\g<\\w+>')

  @classmethod
  @functools.lru_cache(typed=True)
  def _split(cls, template):
    r"""Splits a replacement template into literals and group references.

    The result alternates literal pieces (even indices, with their escapes
    not yet processed) and group references such as \1 or \g<name> (odd
    indices). The result is cached, so callers must not mutate it.
    """
    if isinstance(template, str):
      backslash = '\\'
    else:
      backslash = b'\\'
    empty = type(template)()
    pieces = [empty]
    index = template.find(backslash)
    while index != -1:
      piece, template = template[:index], template[index:]
      pieces[-1] += piece
      octal_match = cls._OCTAL_RE.match(template)
      group_match = cls._GROUP_RE.match(template)
      if (not octal_match) and group_match:
        index = group_match.end()
        piece, template = template[:index], template[index:]
        pieces.extend((piece, empty))
      else:
        # 2 isn't enough for \o, \x, \N, \u and \U escapes, but none of those
        # should contain backslashes, so break them here and then fix them at
        # the beginning of the next loop iteration or right before returning.
        index = 2
        piece, template = template[:index], template[index:]
        pieces[-1] += piece
      index = template.find(backslash)
    pieces[-1] += template
    return pieces

  def expand(self, template):
    """Returns template with escapes processed and group references replaced.

    Group references that did not participate in the match are replaced by
    an empty value of the template's type.
    """
    if isinstance(template, str):
      unescape = codecs.unicode_escape_decode
    else:
      unescape = codecs.escape_decode
    empty = type(template)()
    # Make a copy so that we don't clobber the cached pieces!
    pieces = list(self._split(template))
    for index, piece in enumerate(pieces):
      if not index % 2:
        pieces[index], _ = unescape(piece)
      else:
        if len(piece) <= 3:  # \1 through \99 (inclusive)
          group = int(piece[1:])
        else:  # \g<...>
          group = piece[3:-1]
          try:
            group = int(group)
          except ValueError:
            pass
        pieces[index] = self.__getitem__(group) or empty
    joined_pieces = empty.join(pieces)
    return joined_pieces

  def _resolve(self, group):
    """Returns the index of group, which may be an index or a group name.

    Raises:
      IndexError: if group is an unknown name or an out-of-range index.
    """
    if not isinstance(group, int):
      try:
        group = self._regexp.groupindex[group]
      except KeyError:
        raise IndexError('bad group name') from None
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return group

  def __getitem__(self, group):
    """Returns the text matched by group, or None if it did not participate."""
    span = self._spans[self._resolve(group)]
    if span == _NULL_SPAN:
      return None
    return self._text[span[0]:span[1]]

  def group(self, *groups):
    """Returns one group (default: the whole match) or a tuple of groups."""
    if not groups:
      groups = (0,)
    items = (self.__getitem__(group) for group in groups)
    return next(items) if len(groups) == 1 else tuple(items)

  def groups(self, default=None):
    """Returns all capturing groups, with default for non-participants."""
    items = []
    for group in range(1, self._regexp.groups + 1):
      item = self.__getitem__(group)
      items.append(default if item is None else item)
    return tuple(items)

  def groupdict(self, default=None):
    """Returns named groups as a dict, with default for non-participants."""
    items = []
    for group, index in self._regexp.groupindex.items():
      item = self.__getitem__(index)
      items.append((group, default) if item is None else (group, item))
    return dict(items)

  def start(self, group=0):
    """Returns the start index of group (by index or name), or -1."""
    return self._spans[self._resolve(group)][0]

  def end(self, group=0):
    """Returns the end index of group (by index or name), or -1."""
    return self._spans[self._resolve(group)][1]

  def span(self, group=0):
    """Returns (start, end) of group (by index or name), or (-1, -1)."""
    return self._spans[self._resolve(group)]

  @property
  def re(self):
    """The _Regexp that produced this match."""
    return self._regexp

  @property
  def string(self):
    """The text that was searched."""
    return self._text

  @property
  def pos(self):
    """The (clamped) pos that was passed to the search."""
    return self._pos

  @property
  def endpos(self):
    """The (clamped) endpos that was passed to the search."""
    return self._endpos

  @property
  def lastindex(self):
    """The index of the participating group that ends rightmost, or None."""
    max_end = -1
    max_group = None
    # We look for the rightmost right parenthesis by keeping the first group
    # that ends at max_end because that is the leftmost/outermost group when
    # there are nested groups!
    for group in range(1, self._regexp.groups + 1):
      end = self._spans[group][1]
      if max_end < end:
        max_end = end
        max_group = group
    return max_group

  @property
  def lastgroup(self):
    """The name of the lastindex group, or None if it is unnamed/absent."""
    max_group = self.lastindex
    if not max_group:
      return None
    for group, index in self._regexp.groupindex.items():
      if max_group == index:
        return group
    return None
| 497 | |
| 498 | |
class Set(object):
  """A Pythonic wrapper around RE2::Set.

  Patterns are added with Add() and the set is built with Compile(). Match()
  then returns whatever _re2.Set.Match() reports for the text (presumably
  the indices of the matching patterns - TODO confirm), or None if that
  result is empty.
  """

  # Note the trailing comma: ('_set') is merely the str '_set', which only
  # happens to be accepted as __slots__. A tuple states the intent.
  __slots__ = ('_set',)

  def __init__(self, anchor, options=None):
    """Creates an empty set whose patterns are matched with anchor."""
    options = options or Options()
    self._set = _re2.Set(anchor, options)

  @classmethod
  def SearchSet(cls, options=None):
    """Returns an unanchored Set."""
    return cls(_Anchor.UNANCHORED, options=options)

  @classmethod
  def MatchSet(cls, options=None):
    """Returns a Set anchored at the start of the text."""
    return cls(_Anchor.ANCHOR_START, options=options)

  @classmethod
  def FullMatchSet(cls, options=None):
    """Returns a Set anchored at both ends of the text."""
    return cls(_Anchor.ANCHOR_BOTH, options=options)

  def Add(self, pattern):
    """Adds pattern (str or bytes) to the set and returns its index.

    Raises:
      error: if the underlying set rejects the pattern.
    """
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._set.Add(encoded_pattern)
    else:
      index = self._set.Add(pattern)
    if index == -1:
      raise error('failed to add %r to Set' % pattern)
    return index

  def Compile(self):
    """Builds the set after all patterns have been added.

    Raises:
      error: if compilation fails.
    """
    if not self._set.Compile():
      raise error('failed to compile Set')

  def Match(self, text):
    """Matches text (str or bytes) against the set; returns None if empty."""
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._set.Match(encoded_text)
    else:
      matches = self._set.Match(text)
    return matches or None
| 541 | |
| 542 | |
class Filter(object):
  """A Pythonic wrapper around FilteredRE2.

  Patterns are registered with Add() and the filter is built with Compile().
  Match() then returns whatever the underlying filter reports for the text,
  or None if that result is empty. re() gives access to each registered
  pattern as a _Regexp.
  """

  __slots__ = ('_filter', '_patterns')

  def __init__(self):
    # Keep the original patterns so that re() can report them.
    self._filter, self._patterns = _re2.Filter(), []

  def Add(self, pattern, options=None):
    """Registers pattern (str or bytes) and returns its index.

    Raises:
      error: if the underlying filter rejects the pattern.
    """
    raw = _encode(pattern) if isinstance(pattern, str) else pattern
    index = self._filter.Add(raw, options or Options())
    if index == -1:
      raise error('failed to add %r to Filter' % pattern)
    self._patterns.append(pattern)
    return index

  def Compile(self):
    """Builds the filter after all patterns have been added.

    Raises:
      error: if compilation fails.
    """
    if not self._filter.Compile():
      raise error('failed to compile Filter')

  def Match(self, text, potential=False):
    """Matches text (str or bytes); potential is passed through unchanged."""
    if isinstance(text, str):
      text = _encode(text)
    return self._filter.Match(text, potential) or None

  def re(self, index):
    """Returns a _Regexp proxy for the index-th registered pattern.

    The proxy bypasses _Regexp.__init__ and shares the RE2 object held by
    the filter.

    Raises:
      IndexError: if index is out of range.
    """
    if index < 0 or index >= len(self._patterns):
      raise IndexError('bad index')
    proxy = object.__new__(_Regexp)
    proxy._pattern = self._patterns[index]
    proxy._regexp = self._filter.GetRE2(index)
    return proxy