# Copyright 2019 The RE2 Authors. All Rights Reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
r"""A drop-in replacement for the re module.

It uses RE2 under the hood, of course, so various PCRE features
(e.g. backreferences, look-around assertions) are not supported.
See https://github.com/google/re2/wiki/Syntax for the canonical
reference, but known syntactic "gotchas" relative to Python are:

  * PCRE supports \Z and \z; RE2 supports \z; Python supports \z,
    but calls it \Z. You must rewrite \Z to \z in pattern strings.

Known differences between this module's API and the re module's API:

  * The error class does not provide any error information as attributes.
  * The Options class replaces the re module's flags with RE2's options as
    gettable/settable properties. Please see re2.h for their documentation.
  * The pattern string and the input string do not have to be the same type.
    Any str will be encoded to UTF-8.
  * The pattern string cannot be str if the options specify Latin-1 encoding.

This module's LRU cache contains a maximum of 128 regular expression objects.
Each regular expression object's underlying RE2 object uses a maximum of 8MiB
of memory (by default). Hence, this module's LRU cache uses a maximum of 1GiB
of memory (by default), but in most cases, it should use much less than that.
"""

import codecs
import functools
import itertools

import _re2

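# A rough usage sketch: the module-level functions below mirror the re module.
#
#   import re2
#
#   if re2.search('(?i)hello', 'Hello, world!'):
#     print('greeted')
#   re2.findall(r'\w+', 'one two three')  # ['one', 'two', 'three']
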

class error(Exception):
  pass


class Options(_re2.RE2.Options):

  __slots__ = ()

  NAMES = (
      'max_mem',
      'encoding',
      'posix_syntax',
      'longest_match',
      'log_errors',
      'literal',
      'never_nl',
      'dot_nl',
      'never_capture',
      'case_sensitive',
      'perl_classes',
      'word_boundary',
      'one_line',
  )

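# For example, a sketch of overriding the defaults (the properties mirror
# RE2::Options; see re2.h for what each one controls):
#
#   opts = Options()
#   opts.case_sensitive = False
#   opts.max_mem = 16 << 20
#   regexp = compile('hello world', options=opts)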

def compile(pattern, options=None):
  if isinstance(pattern, _Regexp):
    if options:
      raise error('pattern is already compiled, so '
                  'options may not be specified')
    pattern = pattern._pattern
  options = options or Options()
  values = tuple(getattr(options, name) for name in Options.NAMES)
  return _Regexp._make(pattern, values)


def search(pattern, text, options=None):
  return compile(pattern, options=options).search(text)


def match(pattern, text, options=None):
  return compile(pattern, options=options).match(text)


def fullmatch(pattern, text, options=None):
  return compile(pattern, options=options).fullmatch(text)


def finditer(pattern, text, options=None):
  return compile(pattern, options=options).finditer(text)


def findall(pattern, text, options=None):
  return compile(pattern, options=options).findall(text)


def split(pattern, text, maxsplit=0, options=None):
  return compile(pattern, options=options).split(text, maxsplit)


def subn(pattern, repl, text, count=0, options=None):
  return compile(pattern, options=options).subn(repl, text, count)


def sub(pattern, repl, text, count=0, options=None):
  return compile(pattern, options=options).sub(repl, text, count)

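# Replacement templates may use \1 through \99 and \g<...>, as in the re
# module. A rough sketch:
#
#   sub(r'(\w+)@(\w+)', r'\2.\1', 'user@host')  # 'host.user'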

def _encode(t):
  return t.encode(encoding='utf-8')


def _decode(b):
  return b.decode(encoding='utf-8')


def escape(pattern):
  if isinstance(pattern, str):
    encoded_pattern = _encode(pattern)
    escaped = _re2.RE2.QuoteMeta(encoded_pattern)
    decoded_escaped = _decode(escaped)
    return decoded_escaped
  else:
    escaped = _re2.RE2.QuoteMeta(pattern)
    return escaped


def purge():
  return _Regexp._make.cache_clear()


_Anchor = _re2.RE2.Anchor
_NULL_SPAN = (-1, -1)


class _Regexp(object):

  __slots__ = ('_pattern', '_regexp')

  @classmethod
  @functools.lru_cache(typed=True)
  def _make(cls, pattern, values):
    options = Options()
    for name, value in zip(Options.NAMES, values):
      setattr(options, name, value)
    return cls(pattern, options)

  def __init__(self, pattern, options):
    self._pattern = pattern
    if isinstance(self._pattern, str):
      if options.encoding == Options.Encoding.LATIN1:
        raise error('string type of pattern is str, but '
                    'encoding specified in options is LATIN1')
      encoded_pattern = _encode(self._pattern)
      self._regexp = _re2.RE2(encoded_pattern, options)
    else:
      self._regexp = _re2.RE2(self._pattern, options)
    if not self._regexp.ok():
      raise error(self._regexp.error())

  def __getstate__(self):
    options = {name: getattr(self.options, name) for name in Options.NAMES}
    return self._pattern, options

  def __setstate__(self, state):
    pattern, options = state
    values = tuple(options[name] for name in Options.NAMES)
    other = _Regexp._make(pattern, values)
    self._pattern = other._pattern
    self._regexp = other._regexp

  def _match(self, anchor, text, pos=None, endpos=None):
    pos = 0 if pos is None else max(0, min(pos, len(text)))
    endpos = len(text) if endpos is None else max(0, min(endpos, len(text)))
    if pos > endpos:
      return
    if isinstance(text, str):
      encoded_text = _encode(text)
      encoded_pos = _re2.CharLenToBytes(encoded_text, 0, pos)
      if endpos == len(text):
        # This is the common case.
        encoded_endpos = len(encoded_text)
      else:
        encoded_endpos = encoded_pos + _re2.CharLenToBytes(
            encoded_text, encoded_pos, endpos - pos)
      decoded_offsets = {0: 0}
      last_offset = 0
      while True:
        spans = self._regexp.Match(anchor, encoded_text, encoded_pos,
                                   encoded_endpos)
        if spans[0] == _NULL_SPAN:
          break

        # This algorithm is linear in the length of encoded_text. Specifically,
        # no matter how many groups there are for a given regular expression or
        # how many iterations through the loop there are for a given generator,
        # this algorithm uses a single, straightforward pass over encoded_text.
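        # For example, decoding the byte offsets for 'αβc' ('α' and 'β' are
        # two bytes each in UTF-8): byte offsets 0, 2, 4 and 5 map to str
        # offsets 0, 1, 2 and 3, respectively.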
        offsets = sorted(set(itertools.chain(*spans)))
        if offsets[0] == -1:
          offsets = offsets[1:]
        # Discard the rest of the items because they are useless now - and we
        # could accumulate one item per str offset in the pathological case!
        decoded_offsets = {last_offset: decoded_offsets[last_offset]}
        for offset in offsets:
          decoded_offsets[offset] = (
              decoded_offsets[last_offset] +
              _re2.BytesToCharLen(encoded_text, last_offset, offset))
          last_offset = offset

        def decode(span):
          if span == _NULL_SPAN:
            return span
          return decoded_offsets[span[0]], decoded_offsets[span[1]]

        decoded_spans = [decode(span) for span in spans]
        yield _Match(self, text, pos, endpos, decoded_spans)
        if encoded_pos == encoded_endpos:
          break
        elif encoded_pos == spans[0][1]:
          # We matched the empty string at encoded_pos and would be stuck, so
          # in order to make forward progress, increment the str offset.
          encoded_pos += _re2.CharLenToBytes(encoded_text, encoded_pos, 1)
        else:
          encoded_pos = spans[0][1]
    else:
      while True:
        spans = self._regexp.Match(anchor, text, pos, endpos)
        if spans[0] == _NULL_SPAN:
          break
        yield _Match(self, text, pos, endpos, spans)
        if pos == endpos:
          break
        elif pos == spans[0][1]:
          # We matched the empty string at pos and would be stuck, so in order
          # to make forward progress, increment the bytes offset.
          pos += 1
        else:
          pos = spans[0][1]

  def search(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.UNANCHORED, text, pos, endpos), None)

  def match(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.ANCHOR_START, text, pos, endpos), None)

  def fullmatch(self, text, pos=None, endpos=None):
    return next(self._match(_Anchor.ANCHOR_BOTH, text, pos, endpos), None)

  def finditer(self, text, pos=None, endpos=None):
    return self._match(_Anchor.UNANCHORED, text, pos, endpos)

  def findall(self, text, pos=None, endpos=None):
    empty = type(text)()
    items = []
    for match in self.finditer(text, pos, endpos):
      if not self.groups:
        item = match.group()
      elif self.groups == 1:
        item = match.groups(default=empty)[0]
      else:
        item = match.groups(default=empty)
      items.append(item)
    return items

  def _split(self, cb, text, maxsplit=0):
    if maxsplit < 0:
      return [text], 0
    elif maxsplit > 0:
      matchiter = itertools.islice(self.finditer(text), maxsplit)
    else:
      matchiter = self.finditer(text)
    pieces = []
    end = 0
    numsplit = 0
    for match in matchiter:
      pieces.append(text[end:match.start()])
      pieces.extend(cb(match))
      end = match.end()
      numsplit += 1
    pieces.append(text[end:])
    return pieces, numsplit

  def split(self, text, maxsplit=0):
    cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
    pieces, _ = self._split(cb, text, maxsplit)
    return pieces

  def subn(self, repl, text, count=0):
    cb = lambda match: [repl(match) if callable(repl) else match.expand(repl)]
    empty = type(text)()
    pieces, numsplit = self._split(cb, text, count)
    joined_pieces = empty.join(pieces)
    return joined_pieces, numsplit

  def sub(self, repl, text, count=0):
    joined_pieces, _ = self.subn(repl, text, count)
    return joined_pieces

  @property
  def pattern(self):
    return self._pattern

  @property
  def options(self):
    return self._regexp.options()

  @property
  def groups(self):
    return self._regexp.NumberOfCapturingGroups()

  @property
  def groupindex(self):
    groups = self._regexp.NamedCapturingGroups()
    if isinstance(self._pattern, str):
      decoded_groups = [(_decode(group), index) for group, index in groups]
      return dict(decoded_groups)
    else:
      return dict(groups)

  @property
  def programsize(self):
    return self._regexp.ProgramSize()

  @property
  def reverseprogramsize(self):
    return self._regexp.ReverseProgramSize()

  @property
  def programfanout(self):
    return self._regexp.ProgramFanout()

  @property
  def reverseprogramfanout(self):
    return self._regexp.ReverseProgramFanout()

  def possiblematchrange(self, maxlen):
    ok, min, max = self._regexp.PossibleMatchRange(maxlen)
    if not ok:
      raise error('failed to compute match range')
    return min, max


class _Match(object):

  __slots__ = ('_regexp', '_text', '_pos', '_endpos', '_spans')

  def __init__(self, regexp, text, pos, endpos, spans):
    self._regexp = regexp
    self._text = text
    self._pos = pos
    self._endpos = endpos
    self._spans = spans

  # Python prioritises three-digit octal numbers over group escapes.
  # For example, \100 should not be handled the same way as \g<10>0.
  _OCTAL_RE = compile('\\\\[0-7][0-7][0-7]')

  # Python supports \1 through \99 (inclusive) and \g<...> syntax.
  _GROUP_RE = compile('\\\\[1-9][0-9]?|\\\\g<\\w+>')

  @classmethod
  @functools.lru_cache(typed=True)
  def _split(cls, template):
    if isinstance(template, str):
      backslash = '\\'
    else:
      backslash = b'\\'
    empty = type(template)()
    pieces = [empty]
    index = template.find(backslash)
    while index != -1:
      piece, template = template[:index], template[index:]
      pieces[-1] += piece
      octal_match = cls._OCTAL_RE.match(template)
      group_match = cls._GROUP_RE.match(template)
      if (not octal_match) and group_match:
        index = group_match.end()
        piece, template = template[:index], template[index:]
        pieces.extend((piece, empty))
      else:
        # 2 isn't enough for \o, \x, \N, \u and \U escapes, but none of those
        # should contain backslashes, so break them here and then fix them at
        # the beginning of the next loop iteration or right before returning.
        index = 2
        piece, template = template[:index], template[index:]
        pieces[-1] += piece
      index = template.find(backslash)
    pieces[-1] += template
    return pieces

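  # _split() produces alternating literal and group-reference pieces: even
  # indices hold literal text and odd indices hold a group escape. For example,
  # r'\1-\g<name>' splits into ['', '\\1', '-', '\\g<name>', ''].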
  def expand(self, template):
    if isinstance(template, str):
      unescape = codecs.unicode_escape_decode
    else:
      unescape = codecs.escape_decode
    empty = type(template)()
    # Make a copy so that we don't clobber the cached pieces!
    pieces = list(self._split(template))
    for index, piece in enumerate(pieces):
      if not index % 2:
        pieces[index], _ = unescape(piece)
      else:
        if len(piece) <= 3:  # \1 through \99 (inclusive)
          group = int(piece[1:])
        else:  # \g<...>
          group = piece[3:-1]
          try:
            group = int(group)
          except ValueError:
            pass
        pieces[index] = self.__getitem__(group) or empty
    joined_pieces = empty.join(pieces)
    return joined_pieces

  def __getitem__(self, group):
    if not isinstance(group, int):
      try:
        group = self._regexp.groupindex[group]
      except KeyError:
        raise IndexError('bad group name')
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    span = self._spans[group]
    if span == _NULL_SPAN:
      return None
    return self._text[span[0]:span[1]]

  def group(self, *groups):
    if not groups:
      groups = (0,)
    items = (self.__getitem__(group) for group in groups)
    return next(items) if len(groups) == 1 else tuple(items)

  def groups(self, default=None):
    items = []
    for group in range(1, self._regexp.groups + 1):
      item = self.__getitem__(group)
      items.append(default if item is None else item)
    return tuple(items)

  def groupdict(self, default=None):
    items = []
    for group, index in self._regexp.groupindex.items():
      item = self.__getitem__(index)
      items.append((group, default) if item is None else (group, item))
    return dict(items)

  def start(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group][0]

  def end(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group][1]

  def span(self, group=0):
    if not 0 <= group <= self._regexp.groups:
      raise IndexError('bad group index')
    return self._spans[group]

  @property
  def re(self):
    return self._regexp

  @property
  def string(self):
    return self._text

  @property
  def pos(self):
    return self._pos

  @property
  def endpos(self):
    return self._endpos

  @property
  def lastindex(self):
    max_end = -1
    max_group = None
    # We look for the rightmost right parenthesis by keeping the first group
    # that ends at max_end because that is the leftmost/outermost group when
    # there are nested groups!
    for group in range(1, self._regexp.groups + 1):
      end = self._spans[group][1]
      if max_end < end:
        max_end = end
        max_group = group
    return max_group

  @property
  def lastgroup(self):
    max_group = self.lastindex
    if not max_group:
      return None
    for group, index in self._regexp.groupindex.items():
      if max_group == index:
        return group
    return None


class Set(object):
  """A Pythonic wrapper around RE2::Set."""

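  # A rough usage sketch: Add() the patterns, Compile() once and then Match()
  # returns the indices of the patterns that matched, or None if none did.
  #
  #   s = Set.SearchSet()
  #   s.Add(r'\d+')    # returns 0
  #   s.Add(r'[a-z]+') # returns 1
  #   s.Compile()
  #   s.Match('111')   # [0]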
  __slots__ = ('_set',)

  def __init__(self, anchor, options=None):
    options = options or Options()
    self._set = _re2.Set(anchor, options)

  @classmethod
  def SearchSet(cls, options=None):
    return cls(_Anchor.UNANCHORED, options=options)

  @classmethod
  def MatchSet(cls, options=None):
    return cls(_Anchor.ANCHOR_START, options=options)

  @classmethod
  def FullMatchSet(cls, options=None):
    return cls(_Anchor.ANCHOR_BOTH, options=options)

  def Add(self, pattern):
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._set.Add(encoded_pattern)
    else:
      index = self._set.Add(pattern)
    if index == -1:
      raise error('failed to add %r to Set' % pattern)
    return index

  def Compile(self):
    if not self._set.Compile():
      raise error('failed to compile Set')

  def Match(self, text):
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._set.Match(encoded_text)
    else:
      matches = self._set.Match(text)
    return matches or None


class Filter(object):
  """A Pythonic wrapper around FilteredRE2."""

  __slots__ = ('_filter', '_patterns')

  def __init__(self):
    self._filter = _re2.Filter()
    self._patterns = []

  def Add(self, pattern, options=None):
    options = options or Options()
    if isinstance(pattern, str):
      encoded_pattern = _encode(pattern)
      index = self._filter.Add(encoded_pattern, options)
    else:
      index = self._filter.Add(pattern, options)
    if index == -1:
      raise error('failed to add %r to Filter' % pattern)
    self._patterns.append(pattern)
    return index

  def Compile(self):
    if not self._filter.Compile():
      raise error('failed to compile Filter')

  def Match(self, text, potential=False):
    if isinstance(text, str):
      encoded_text = _encode(text)
      matches = self._filter.Match(encoded_text, potential)
    else:
      matches = self._filter.Match(text, potential)
    return matches or None

  def re(self, index):
    if not 0 <= index < len(self._patterns):
      raise IndexError('bad index')
    proxy = object.__new__(_Regexp)
    proxy._pattern = self._patterns[index]
    proxy._regexp = self._filter.GetRE2(index)
    return proxy