Support multi-phase (PEP 489) module initialization in the Python extension

PiperOrigin-RevId: 811797785
diff --git a/python/_brotli.c b/python/_brotli.c
index 1f9150e..d1fd3d2 100644
--- a/python/_brotli.c
+++ b/python/_brotli.c
@@ -1,6 +1,6 @@
 #include <assert.h>
-#include <stdio.h>
 #include <stddef.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #define PY_SSIZE_T_CLEAN 1
@@ -18,7 +18,12 @@
 #define PY_GET_TYPE(Obj) ((Obj)->ob_type)
 #endif
 
+static const char kErrorAttr[] = "error"; /* Module attr: the error type. */
+#if PY_MAJOR_VERSION >= 3
+static const char kModuleAttr[] = "_module"; /* Type attr: owning module. */
+#else
 static PyObject* BrotliError;
+#endif
 
 static const char kInvalidBufferError[] =
     "brotli: data must be a C-contiguous buffer";
@@ -195,8 +200,39 @@
 "Implementation module for the Brotli library.");
 /* clang-format on */
 
-static void set_brotli_exception(const char* msg) {
+/*
+   Sets a brotli error exception with the given message.
+
+   `context` could be a module, a module type, or an instance of a module type.
+   On Python 3 the exception type is the module's "error" attribute; when
+   `context` is a type or an instance, the module is first fetched from the
+   type's "_module" attribute. If a lookup fails, the error it raised
+   (typically AttributeError) is left set instead of the brotli error.
+   On Python 2 `context` is ignored and the global BrotliError is raised.
+*/
+static void set_brotli_exception(PyObject* context, const char* msg) {
+#if PY_MAJOR_VERSION >= 3
+  PyObject* error = NULL;
+  assert(context != NULL);
+
+  if (PyModule_Check(context)) {
+    error = PyObject_GetAttrString(context, kErrorAttr);
+  } else {
+    PyObject* type =
+        PyType_Check(context) ? context : (PyObject*)PY_GET_TYPE(context);
+    PyObject* module = PyObject_GetAttrString(type, kModuleAttr);
+    if (!module) return; /* AttributeError raised. */
+    error = PyObject_GetAttrString(module, kErrorAttr);
+    Py_DECREF(module);
+  }
+
+  if (error == NULL) return; /* AttributeError raised. */
+  PyErr_SetString(error, msg);
+  Py_DECREF(error);
+#else
+  (void)context; /* Unused: Python 2 raises the global BrotliError. */
   PyErr_SetString(BrotliError, msg);
+#endif
 }
 
 /*
@@ -316,7 +346,7 @@
 
   result = PyBytes_FromStringAndSize(NULL, len);
   if (result == NULL) {
-    PyErr_Clear();  /* OOM exception will be raised by callers. */
+    PyErr_Clear(); /* OOM exception will be raised by callers. */
     return NULL;
   }
   if (len == 0) return result;
@@ -354,7 +384,7 @@
   self->processing = 0;
   self->enc = BrotliEncoderCreateInstance(0, 0, 0);
   if (self->enc == NULL) {
-    set_brotli_exception(kCompressCreateError);
+    set_brotli_exception((PyObject*)type, kCompressCreateError);
     PY_GET_TYPE(self)->tp_free((PyObject*)self);
     return NULL;
   }
@@ -386,7 +416,7 @@
   if ((mode == 0) || (mode == 1) || (mode == 2)) {
     BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_MODE, (uint32_t)mode);
   } else {
-    set_brotli_exception(kInvalidModeError);
+    set_brotli_exception((PyObject*)self, kInvalidModeError);
     self->healthy = 0;
     return -1;
   }
@@ -394,14 +424,14 @@
     BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_QUALITY,
                               (uint32_t)quality);
   } else {
-    set_brotli_exception(kInvalidQualityError);
+    set_brotli_exception((PyObject*)self, kInvalidQualityError);
     self->healthy = 0;
     return -1;
   }
   if ((10 <= lgwin) && (lgwin <= 24)) {
     BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGWIN, (uint32_t)lgwin);
   } else {
-    set_brotli_exception(kInvalidLgwinError);
+    set_brotli_exception((PyObject*)self, kInvalidLgwinError);
     self->healthy = 0;
     return -1;
   }
@@ -409,7 +439,7 @@
     BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGBLOCK,
                               (uint32_t)lgblock);
   } else {
-    set_brotli_exception(kInvalidLgblockError);
+    set_brotli_exception((PyObject*)self, kInvalidLgblockError);
     self->healthy = 0;
     return -1;
   }
@@ -470,8 +500,8 @@
   if (ok) {
     ret = Buffer_Finish(&buffer);
     if (ret == NULL) oom = 1;
-  } else {  /* Not ok */
-    set_brotli_exception(kCompressError);
+  } else { /* Not ok */
+    set_brotli_exception((PyObject*)self, kCompressError);
   }
 
 error:
@@ -495,11 +525,11 @@
   Py_buffer input;
 
   if (self->healthy == 0) {
-    set_brotli_exception(kCompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kCompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kCompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kCompressConcurrentError);
     return NULL;
   }
 
@@ -522,11 +552,11 @@
   PyObject* ret = NULL;
 
   if (self->healthy == 0) {
-    set_brotli_exception(kCompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kCompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kCompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kCompressConcurrentError);
     return NULL;
   }
 
@@ -540,11 +570,11 @@
   PyObject* ret = NULL;
 
   if (self->healthy == 0) {
-    set_brotli_exception(kCompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kCompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kCompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kCompressConcurrentError);
     return NULL;
   }
 
@@ -578,7 +608,7 @@
 
   self->dec = BrotliDecoderCreateInstance(0, 0, 0);
   if (self->dec == NULL) {
-    set_brotli_exception(kDecompressCreateError);
+    set_brotli_exception((PyObject*)type, kDecompressCreateError);
     PY_GET_TYPE(self)->tp_free((PyObject*)self);
     return NULL;
   }
@@ -633,11 +663,11 @@
   int oom = 0;
 
   if (self->healthy == 0) {
-    set_brotli_exception(kDecompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kDecompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kDecompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kDecompressConcurrentError);
     return NULL;
   }
 
@@ -654,7 +684,7 @@
 
   if (self->unconsumed_data_length > 0) {
     if (input.len > 0) {
-      set_brotli_exception(kDecompressSinkError);
+      set_brotli_exception((PyObject*)self, kDecompressSinkError);
       goto finally;
     }
     next_in = self->unconsumed_data;
@@ -693,7 +723,7 @@
   if (oom) {
     goto finally;
   } else if (result == BROTLI_DECODER_RESULT_ERROR) {
-    set_brotli_exception(kDecompressError);
+    set_brotli_exception((PyObject*)self, kDecompressError);
     goto finally;
   }
 
@@ -712,7 +742,7 @@
 
   if ((result == BROTLI_DECODER_RESULT_SUCCESS) && (avail_in > 0)) {
     /* TODO(eustas): Add API to ignore / fetch unused "tail"? */
-    set_brotli_exception(kDecompressError);
+    set_brotli_exception((PyObject*)self, kDecompressError);
     goto finally;
   }
 
@@ -742,11 +772,11 @@
 
 static PyObject* brotli_Decompressor_is_finished(PyBrotli_Decompressor* self) {
   if (self->healthy == 0) {
-    set_brotli_exception(kDecompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kDecompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kDecompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kDecompressConcurrentError);
     return NULL;
   }
   if (BrotliDecoderIsFinished(self->dec)) {
@@ -759,11 +789,11 @@
 static PyObject* brotli_Decompressor_can_accept_more_data(
     PyBrotli_Decompressor* self) {
   if (self->healthy == 0) {
-    set_brotli_exception(kDecompressUnhealthyError);
+    set_brotli_exception((PyObject*)self, kDecompressUnhealthyError);
     return NULL;
   }
   if (self->processing != 0) {
-    set_brotli_exception(kDecompressConcurrentError);
+    set_brotli_exception((PyObject*)self, kDecompressConcurrentError);
     return NULL;
   }
   if (self->unconsumed_data_length > 0) {
@@ -832,7 +862,7 @@
   if (oom) {
     goto finally;
   } else if (result != BROTLI_DECODER_RESULT_SUCCESS || available_in > 0) {
-    set_brotli_exception(kDecompressError);
+    set_brotli_exception((PyObject*)self, kDecompressError);
     goto finally;
   }
 
@@ -849,6 +879,8 @@
 
 /* Module definition */
 
+static int init_brotli_mod(PyObject* m);
+
 static PyMethodDef brotli_methods[] = {
     {"decompress", (PyCFunction)brotli_decompress, METH_VARARGS | METH_KEYWORDS,
      brotli_decompress__doc__},
@@ -877,16 +909,29 @@
 
 #if PY_MAJOR_VERSION >= 3
 
+#if PY_MINOR_VERSION >= 5
+static PyModuleDef_Slot brotli_mod_slots[] = {
+    {Py_mod_exec, init_brotli_mod}, /* Populates the created module. */
+#if PY_MINOR_VERSION >= 12
+    {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED},
+#endif
+    {0, NULL}}; /* Sentinel. */
+#endif
+
 static struct PyModuleDef brotli_module = {
     PyModuleDef_HEAD_INIT,
     "_brotli",      /* m_name */
     brotli_doc,     /* m_doc */
     0,              /* m_size */
     brotli_methods, /* m_methods */
-    NULL,           /* m_reload */
-    NULL,           /* m_traverse */
-    NULL,           /* m_clear */
-    NULL            /* m_free */
+#if PY_MINOR_VERSION >= 5
+    brotli_mod_slots, /* m_slots */
+#else /* Before 3.5 this field is m_reload and there are no slots. */
+    NULL, /* m_reload */
+#endif
+    NULL, /* m_traverse */
+    NULL, /* m_clear */
+    NULL  /* m_free */
 };
 
 static PyType_Slot brotli_Compressor_slots[] = {
@@ -915,6 +960,26 @@
     "brotli.Decompressor", sizeof(PyBrotli_Decompressor), 0,
     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, brotli_Decompressor_slots};
 
+#if PY_MINOR_VERSION >= 5
+/* Multi-phase init (PEP 489): the interpreter creates the module from
+   brotli_module and then runs its Py_mod_exec slot (init_brotli_mod). */
+PyMODINIT_FUNC PyInit__brotli(void) { return PyModuleDef_Init(&brotli_module); }
+#else
+/* Python 3.0-3.4 predate PyModuleDef_Init and m_slots (brotli_module carries
+   m_reload there), so create the module and run init_brotli_mod directly.
+   NOTE(review): assumes init_brotli_mod, like a Py_mod_exec function, does not
+   release `m` on failure -- confirm against its error path. */
+PyMODINIT_FUNC PyInit__brotli(void) {
+  PyObject* m = PyModule_Create(&brotli_module);
+  if (m == NULL) return NULL;
+  if (init_brotli_mod(m) < 0) {
+    Py_DECREF(m);
+    return NULL;
+  }
+  return m;
+}
+#endif
+
 #else
 
 static PyTypeObject brotli_CompressorType = {
@@ -999,18 +1046,11 @@
     brotli_Decompressor_new,                 /* tp_new */
 };
 
-#endif
+PyMODINIT_FUNC init_brotli(void) { /* Python 2 single-phase init. */
+  PyObject* m = Py_InitModule3("_brotli", brotli_methods, brotli_doc);
+  (void)init_brotli_mod(m); /* Py2 init returns void; result is dropped. */
+}
 
-#if PY_MAJOR_VERSION >= 3
-#define INIT_BROTLI PyInit__brotli
-#define CREATE_BROTLI PyModule_Create(&brotli_module)
-#define RETURN_BROTLI return m
-#define RETURN_NULL return NULL
-#else
-#define INIT_BROTLI init_brotli
-#define CREATE_BROTLI Py_InitModule3("_brotli", brotli_methods, brotli_doc)
-#define RETURN_BROTLI return
-#define RETURN_NULL return
 #endif
 
 /* Emulates PyModule_AddObject */
@@ -1027,8 +1067,7 @@
 #endif
 }
 
-PyMODINIT_FUNC INIT_BROTLI(void) {
-  PyObject* m = CREATE_BROTLI;
+static int init_brotli_mod(PyObject* m) {
   PyObject* error_type = NULL;
   PyObject* compressor_type = NULL;
   PyObject* decompressor_type = NULL;
@@ -1039,27 +1078,35 @@
                                          brotli_error_doc, NULL, NULL);
   if (error_type == NULL) goto error;
 
-  if (RegisterObject(m, "error", error_type) < 0) goto error;
+  if (RegisterObject(m, kErrorAttr, error_type) < 0) goto error;
+#if PY_MAJOR_VERSION < 3
   /* Assumption: pointer is used only while module is alive and well. */
   BrotliError = error_type;
+#endif
   error_type = NULL;
 
 #if PY_MAJOR_VERSION >= 3
   compressor_type = PyType_FromSpec(&brotli_Compressor_spec);
   decompressor_type = PyType_FromSpec(&brotli_Decompressor_spec);
 #else
-  compressor_type = &brotli_CompressorType;
+  compressor_type = (PyObject*)&brotli_CompressorType;
   Py_INCREF(compressor_type);
-  decompressor_type = &brotli_DecompressorType;
+  decompressor_type = (PyObject*)&brotli_DecompressorType;
   Py_INCREF(decompressor_type);
 #endif
   if (compressor_type == NULL) goto error;
   if (PyType_Ready((PyTypeObject*)compressor_type) < 0) goto error;
+#if PY_MAJOR_VERSION >= 3
+  if (PyObject_SetAttrString(compressor_type, kModuleAttr, m) < 0) goto error;
+#endif
   if (RegisterObject(m, "Compressor", compressor_type) < 0) goto error;
   compressor_type = NULL;
 
   if (decompressor_type == NULL) goto error;
   if (PyType_Ready((PyTypeObject*)decompressor_type) < 0) goto error;
+#if PY_MAJOR_VERSION >= 3
+  if (PyObject_SetAttrString(decompressor_type, kModuleAttr, m) < 0) goto error;
+#endif
   if (RegisterObject(m, "Decompressor", decompressor_type) < 0) goto error;
   decompressor_type = NULL;
 
@@ -1073,7 +1120,7 @@
            (decoderVersion >> 12) & 0xFFF, decoderVersion & 0xFFF);
   PyModule_AddStringConstant(m, "__version__", version);
 
-  RETURN_BROTLI;
+  return 0;
 
 error:
   if (m != NULL) {
@@ -1092,5 +1139,5 @@
     Py_DECREF(decompressor_type);
     decompressor_type = NULL;
   }
-  RETURN_NULL;
+  return -1;
 }