PY: fix set_brotli_exception

PiperOrigin-RevId: 803421474
diff --git a/python/_brotli.c b/python/_brotli.c
index 7b1c227..6ddf9c8 100644
--- a/python/_brotli.c
+++ b/python/_brotli.c
@@ -36,24 +36,52 @@
 }
 #endif
 
-static void set_brotli_exception(const char* message) {
-#if PY_MAJOR_VERSION >= 3
-  PyObject* module = PyImport_ImportModule("_brotli");
+/* Sets a brotli error carrying `message` as the pending Python exception.
+ *
+ * `context` is either the extension module itself or an object whose type
+ * is defined by it (e.g. a Compressor or Decompressor instance). In the
+ * latter case the module is located by importing the name stored in the
+ * type's `__module__`.
+ * NOTE(review): this assumes that name refers to the extension module that
+ * owns the error type (i.e. the types' tp_name prefix) -- confirm.
+ * If the module cannot be located (or, on Python 3, carries no brotli
+ * state), a RuntimeError carrying `message` is set instead. On Python 2 the
+ * error type is read from the module's `error` attribute and cached.
+ */
+static void set_brotli_exception(PyObject* context, const char* message) {
+  PyObject* module = NULL; /* owned reference once located */
+  if (PyModule_Check(context)) {
+    Py_INCREF(context);
+    module = context;
+  } else {
+    /* `__module__` is the module *name* (a str), not the module object, and
+     * instances of C types need not expose it, so read it from the type. */
+    PyObject* module_name =
+        PyObject_GetAttrString((PyObject*)Py_TYPE(context), "__module__");
+    if (module_name != NULL) {
+      module = PyImport_Import(module_name);
+      Py_DECREF(module_name);
+    }
+  }
   if (module == NULL) {
-    PyErr_SetString(PyExc_RuntimeError, "Module not found");
+    PyErr_SetString(PyExc_RuntimeError, message);
     return;
   }
+#if PY_MAJOR_VERSION >= 3
   brotli_module_state* state = get_brotli_state(module);
-  PyErr_SetString(state->BrotliError, message);
+  if (state != NULL && state->BrotliError != NULL) {
+    PyErr_SetString(state->BrotliError, message);
+  } else {
+    /* Presumably `module` is not the extension module. */
+    PyErr_SetString(PyExc_RuntimeError, message);
+  }
+  Py_DECREF(module);
 #else
   /* For Python 2, use static global */
   static PyObject* brotli_error = NULL;
   if (brotli_error == NULL) {
-    PyObject* module_dict = PyImport_GetModuleDict();
-    PyObject* module = PyDict_GetItemString(module_dict, "_brotli");
-    if (module != NULL) {
-      brotli_error = PyObject_GetAttrString(module, "error");
-    }
+    brotli_error = PyObject_GetAttrString(module, "error");
   }
+  Py_DECREF(module);
   if (brotli_error != NULL) {
     PyErr_SetString(brotli_error, message);
@@ -263,76 +260,41 @@
   Py_CLEAR(buffer->list);
 }
 
-static int as_bounded_int(PyObject* o, int* result, int lower_bound,
-                          int upper_bound) {
-  long value = PyInt_AsLong(o);
-  if ((value < (long)lower_bound) || (value > (long)upper_bound)) {
-    return 0;
+/* Optional integer argument, as recorded by int_convertor. */
+typedef struct {
+  int8_t value;
+  int8_t ok; /* -1 default, 0 error, 1 success */
+} brotli_maybe_int;
+
+/* "O&" converter for small optional integer arguments.
+ *
+ * Reports the outcome through `rv` instead of raising: on success
+ * rv->ok = 1 and rv->value holds the value; if `obj` is not an integer or
+ * lies outside [-128, 127], rv->ok = 0 and rv->value is left untouched.
+ * Callers are expected to turn rv->ok == 0 into a descriptive error.
+ *
+ * Always returns 1 with no exception pending: a PyArg_Parse* converter that
+ * returns 0 must have set an exception, and PyInt_AsLong may leave an
+ * OverflowError behind for integers that do not fit in a C long.
+ */
+static int int_convertor(PyObject* obj, brotli_maybe_int* rv) {
+  if (!PyInt_Check(obj)) {
+    rv->ok = 0;
+    return 1;
   }
-  *result = (int)value;
-  return 1;
-}
-
-static int mode_convertor(PyObject* o, BrotliEncoderMode* mode) {
-  if (!PyInt_Check(o)) {
-    set_brotli_exception("Invalid mode");
-    return 0;
+  long value = PyInt_AsLong(obj);
+  if (value == -1 && PyErr_Occurred()) {
+    /* Too large for a C long: treat it like any other invalid value. */
+    PyErr_Clear();
+    rv->ok = 0;
+    return 1;
+  }
+  if ((value < -128) || (value > 127)) {
+    rv->ok = 0;
+    return 1;
   }
-
-  int mode_value = -1;
-  if (!as_bounded_int(o, &mode_value, 0, 255)) {
-    set_brotli_exception("Invalid mode");
-    return 0;
-  }
-  *mode = (BrotliEncoderMode)mode_value;
-  if (*mode != BROTLI_MODE_GENERIC && *mode != BROTLI_MODE_TEXT &&
-      *mode != BROTLI_MODE_FONT) {
-    set_brotli_exception("Invalid mode");
-    return 0;
-  }
-
-  return 1;
-}
-
-static int quality_convertor(PyObject* o, int* quality) {
-  if (!PyInt_Check(o)) {
-    set_brotli_exception("Invalid quality");
-    return 0;
-  }
-
-  if (!as_bounded_int(o, quality, 0, 11)) {
-    set_brotli_exception("Invalid quality. Range is 0 to 11.");
-    return 0;
-  }
-
-  return 1;
-}
-
-static int lgwin_convertor(PyObject* o, int* lgwin) {
-  if (!PyInt_Check(o)) {
-    set_brotli_exception("Invalid lgwin");
-    return 0;
-  }
-
-  if (!as_bounded_int(o, lgwin, 10, 24)) {
-    set_brotli_exception("Invalid lgwin. Range is 10 to 24.");
-    return 0;
-  }
-
-  return 1;
-}
-
-static int lgblock_convertor(PyObject* o, int* lgblock) {
-  if (!PyInt_Check(o)) {
-    set_brotli_exception("Invalid lgblock");
-    return 0;
-  }
-
-  if (!as_bounded_int(o, lgblock, 0, 24) || (*lgblock != 0 && *lgblock < 16)) {
-    set_brotli_exception("Invalid lgblock. Can be 0 or in range 16 to 24.");
-    return 0;
-  }
-
+  rv->value = (int8_t)value;
+  rv->ok = 1;
   return 1;
 }
 
@@ -436,31 +380,71 @@
 
+/* Returns 1 when `arg` is unusable: it was supplied but could not be
+ * converted (ok == 0), or it was converted (ok == 1) but lies outside
+ * [lo, hi]. An omitted argument (ok == -1) is never rejected. */
+static int brotli_maybe_int_rejected(const brotli_maybe_int* arg, int lo,
+                                     int hi) {
+  if (arg->ok == 0) return 1;
+  if (arg->ok != 1) return 0;
+  return arg->value < lo || arg->value > hi;
+}
+
+/* Compressor.__init__: parses the optional `mode`, `quality`, `lgwin` and
+ * `lgblock` arguments, validates them, and applies the supplied ones to the
+ * encoder. Returns 0 on success and -1 on failure; an invalid argument is
+ * reported via set_brotli_exception. */
 static int brotli_Compressor_init(brotli_Compressor* self, PyObject* args,
                                   PyObject* keywds) {
-  BrotliEncoderMode mode = (BrotliEncoderMode)-1;
-  int quality = -1;
-  int lgwin = -1;
-  int lgblock = -1;
-  int ok;
-
-  static const char* kwlist[] = {"mode", "quality", "lgwin", "lgblock", NULL};
-
-  ok = PyArg_ParseTupleAndKeywords(
-      args, keywds, "|O&O&O&O&:Compressor", (char**)kwlist, &mode_convertor,
-      &mode, &quality_convertor, &quality, &lgwin_convertor, &lgwin,
-      &lgblock_convertor, &lgblock);
-  if (!ok) return -1;
-  if (!self->enc) return -1;
+  static const char* kwlist[] = {"mode", "quality", "lgwin", "lgblock", NULL};
+  brotli_maybe_int mode = {-1, -1};
+  brotli_maybe_int quality = {-1, -1};
+  brotli_maybe_int lgwin = {-1, -1};
+  brotli_maybe_int lgblock = {-1, -1};
+  const char* complaint = NULL;
+  int parsed = PyArg_ParseTupleAndKeywords(
+      args, keywds, "|O&O&O&O&:Compressor", (char**)kwlist, &int_convertor,
+      &mode, &int_convertor, &quality, &int_convertor, &lgwin, &int_convertor,
+      &lgblock);
+
+  /* Arguments are checked before `parsed`, so a rejected value is reported
+   * as a brotli error even when the parse itself failed. The first bad
+   * argument, in keyword order, determines the message. */
+  if (mode.ok == 0 ||
+      (mode.ok == 1 && mode.value != BROTLI_MODE_GENERIC &&
+       mode.value != BROTLI_MODE_TEXT && mode.value != BROTLI_MODE_FONT)) {
+    complaint = "Invalid mode";
+  } else if (brotli_maybe_int_rejected(&quality, 0, 11)) {
+    complaint = "Invalid quality; range is 0 to 11";
+  } else if (brotli_maybe_int_rejected(&lgwin, 10, 24)) {
+    complaint = "Invalid lgwin; range is 10 to 24";
+  } else if (lgblock.ok == 0 ||
+             (lgblock.ok == 1 && lgblock.value != 0 &&
+              (lgblock.value < 16 || lgblock.value > 24))) {
+    /* 0 is accepted in addition to the 16..24 range. */
+    complaint = "Invalid lgblock; range is 16 to 24, or 0";
+  }
+  if (complaint != NULL) {
+    set_brotli_exception((PyObject*)self, complaint);
+    return -1;
+  }
+  if (!parsed || !self->enc) return -1;
 
-  if ((int)mode != -1)
-    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_MODE, (uint32_t)mode);
-  if (quality != -1)
-    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_QUALITY,
-                              (uint32_t)quality);
-  if (lgwin != -1)
-    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGWIN, (uint32_t)lgwin);
-  if (lgblock != -1)
-    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGBLOCK,
-                              (uint32_t)lgblock);
+  /* Only arguments that were actually supplied are passed to the encoder. */
+  if (mode.ok == 1) {
+    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_MODE,
+                              (uint32_t)mode.value);
+  }
+  if (quality.ok == 1) {
+    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_QUALITY,
+                              (uint32_t)quality.value);
+  }
+  if (lgwin.ok == 1) {
+    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGWIN,
+                              (uint32_t)lgwin.value);
+  }
+  if (lgblock.ok == 1) {
+    BrotliEncoderSetParameter(self->enc, BROTLI_PARAM_LGBLOCK,
+                              (uint32_t)lgblock.value);
+  }
 
   return 0;
 }
@@ -516,6 +489,7 @@
 
 error:
   set_brotli_exception(
+      (PyObject*)self, /* context: locates the module's error type */
       "BrotliEncoderCompressStream failed while processing the stream");
   ret = NULL;
 
@@ -554,6 +528,7 @@
 
 error:
   set_brotli_exception(
+      (PyObject*)self, /* context: locates the module's error type */
       "BrotliEncoderCompressStream failed while flushing the stream");
   ret = NULL;
 finally:
@@ -597,6 +572,7 @@
 
 error:
   set_brotli_exception(
+      (PyObject*)self, /* context: locates the module's error type */
       "BrotliEncoderCompressStream failed while finishing the stream");
   ret = NULL;
 finally:
@@ -852,6 +828,7 @@
   if (self->unconsumed_data_length > 0) {
     if (input.len > 0) {
       set_brotli_exception(
+          (PyObject*)self, /* context: locates the module's error type */
           "process called with data when accept_more_data is False");
       ret = NULL;
       goto finally;
@@ -870,6 +847,7 @@
 
 error:
   set_brotli_exception(
+      (PyObject*)self, /* context: locates the module's error type */
       "BrotliDecoderDecompressStream failed while processing the stream");
   ret = NULL;
 
@@ -896,6 +874,7 @@
 static PyObject* brotli_Decompressor_is_finished(brotli_Decompressor* self) {
   if (!self->dec) {
     set_brotli_exception(
+        (PyObject*)self, /* context: locates the module's error type */
         "BrotliDecoderState is NULL while checking is_finished");
     return NULL;
   }
@@ -1075,7 +1054,7 @@
 
 error:
   BlocksOutputBuffer_OnError(&buffer);
-  set_brotli_exception("BrotliDecompress failed");
+  set_brotli_exception(/*context=*/(PyObject*)self, "BrotliDecompress failed");
   ret = NULL;
 
 finally: