Improve callback handling inside oneofs. (#175)

Removed PB_OLD_CALLBACK_STYLE. Added generator option
submsg_callback to have a message-level callback for
setting up callbacks inside oneofs. Added example/testcase
about using callbacks with oneofs.
diff --git a/docs/migration.rst b/docs/migration.rst
index e077a74..1fee491 100644
--- a/docs/migration.rst
+++ b/docs/migration.rst
@@ -151,6 +151,20 @@
 
 **Error indications:** Submessages do not get encoded.
 
+PB_OLD_CALLBACK_STYLE option has been removed
+---------------------------------------------
+**Rationale:** Back in 2013, function signature for callbacks was changed. The
+`PB_OLD_CALLBACK_STYLE` option allowed compatibility with old code, but
+complicated code and testing because of the different options.
+
+**Changes:** `PB_OLD_CALLBACK_STYLE` option no-longer has any effect.
+
+**Required actions:** If `PB_OLD_CALLBACK_STYLE` option was in use previously,
+function signatures must be updated to use double pointers (`void**` and
+`void * const *`).
+
+**Error indications:** Assignment from incompatible pointer type.
+
 Nanopb-0.3.9.4, 0.4.0 (2019-xx-xx)
 ==================================
 
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index d14860d..327801c 100755
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -446,6 +446,8 @@
             self.pbtype = 'MESSAGE'
             self.ctype = self.submsgname = names_from_type_name(desc.type_name)
             self.enc_size = None # Needs to be filled in after the message type is available
+            if field_options.submsg_callback:
+                self.pbtype = 'MSG_W_CB'
         else:
             raise NotImplementedError(desc.type)
 
@@ -456,9 +458,11 @@
         result = ''
         if self.allocation == 'POINTER':
             if self.rules == 'REPEATED':
+                if self.pbtype == 'MSG_W_CB':
+                    result += '    pb_callback_t cb_' + self.name + ';\n'
                 result += '    pb_size_t ' + self.name + '_count;\n'
 
-            if self.pbtype == 'MESSAGE':
+            if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
                 # Use struct definition, so recursive submessages are possible
                 result += '    struct _%s *%s;' % (self.ctype, self.name)
             elif self.pbtype == 'FIXED_LENGTH_BYTES':
@@ -472,6 +476,9 @@
         elif self.allocation == 'CALLBACK':
             result += '    %s %s;' % (self.callback_datatype, self.name)
         else:
+            if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']:
+                result += '    pb_callback_t cb_' + self.name + ';\n'
+
             if self.rules == 'OPTIONAL':
                 result += '    bool has_' + self.name + ';\n'
             elif self.rules == 'REPEATED':
@@ -501,7 +508,7 @@
         '''
 
         inner_init = None
-        if self.pbtype == 'MESSAGE':
+        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
             if null_init:
                 inner_init = '%s_init_zero' % self.ctype
             else:
@@ -574,6 +581,9 @@
             else:
                 outer_init = '{{NULL}, NULL}'
 
+        if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']:
+            outer_init = '{{NULL}, NULL}, ' + outer_init
+
         return outer_init
 
     def tags(self):
@@ -607,7 +617,7 @@
             size = 8
         elif self.allocation == 'CALLBACK':
             size = 16
-        elif self.pbtype == 'MESSAGE':
+        elif self.pbtype in ['MESSAGE', 'MSG_W_CB']:
             if str(self.submsgname) in dependencies:
                 size = dependencies[str(self.submsgname)].data_size(dependencies)
             else:
@@ -641,7 +651,7 @@
         if self.allocation != 'STATIC':
             return None
 
-        if self.pbtype == 'MESSAGE':
+        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
             encsize = None
             if str(self.submsgname) in dependencies:
                 submsg = dependencies[str(self.submsgname)]
@@ -698,12 +708,11 @@
 
         return encsize
 
-    def requires_custom_field_callback(self):
-        if self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t':
-            return True
-        else:
-            return False
+    def has_callbacks(self):
+        return self.allocation == 'CALLBACK'
 
+    def requires_custom_field_callback(self):
+        return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t'
 
 class ExtensionRange(Field):
     def __init__(self, struct_name, range_start, field_options):
@@ -808,24 +817,27 @@
         self.default = None
         self.rules = 'ONEOF'
         self.anonymous = False
+        self.has_msg_cb = False
 
     def add_field(self, field):
-        if field.allocation == 'CALLBACK':
-            raise Exception("Callback fields inside of oneof are not supported"
-                            + " (field %s)" % field.name)
-
         field.union_name = self.name
         field.rules = 'ONEOF'
         field.anonymous = self.anonymous
         self.fields.append(field)
         self.fields.sort(key = lambda f: f.tag)
 
+        if field.pbtype == 'MSG_W_CB':
+            self.has_msg_cb = True
+
         # Sort by the lowest tag number inside union
         self.tag = min([f.tag for f in self.fields])
 
     def __str__(self):
         result = ''
         if self.fields:
+            if self.has_msg_cb:
+                result += '    pb_callback_t cb_' + self.name + ';\n'
+
             result += '    pb_size_t which_' + self.name + ";\n"
             result += '    union {\n'
             for f in self.fields:
@@ -846,7 +858,10 @@
         return deps
 
     def get_initializer(self, null_init):
-        return '0, {' + self.fields[0].get_initializer(null_init) + '}'
+        if self.has_msg_cb:
+            return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}'
+        else:
+            return '0, {' + self.fields[0].get_initializer(null_init) + '}'
 
     def tags(self):
         return ''.join([f.tags() for f in self.fields])
@@ -887,6 +902,12 @@
             union_def = ' '.join('char f%d[%s];' % s for s in symbols)
             return EncodedSize(5, ['sizeof(union{%s})' % union_def])
 
+    def has_callbacks(self):
+        return bool([f for f in self.fields if f.has_callbacks()])
+
+    def requires_custom_field_callback(self):
+        return bool([f for f in self.fields if f.requires_custom_field_callback()])
+
 # ---------------------------------------------------------------------------
 #                   Generation of messages (structures)
 # ---------------------------------------------------------------------------
@@ -1042,7 +1063,7 @@
         result += ' \\\n'.join(field.fieldlist() for field in sorted(self.fields))
         result += '\n'
 
-        has_callbacks = bool([f for f in self.fields if f.allocation == 'CALLBACK'])
+        has_callbacks = bool([f for f in self.fields if f.has_callbacks()])
         if has_callbacks:
             if self.callback_function != 'pb_default_field_callback':
                 result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function
@@ -1058,11 +1079,11 @@
             result += '#define %s_DEFAULT NULL\n' % self.name
 
         for field in sorted(self.fields):
-            if field.pbtype == 'MESSAGE':
+            if field.pbtype in ['MESSAGE', 'MSG_W_CB']:
                 result += "#define %s_%s_MSGTYPE %s\n" % (self.name, field.name, field.ctype)
             elif field.rules == 'ONEOF':
                 for member in field.fields:
-                    if member.pbtype == 'MESSAGE':
+                    if member.pbtype in ['MESSAGE', 'MSG_W_CB']:
                         result += "#define %s_%s_%s_MSGTYPE %s\n" % (self.name, member.union_name, member.name, member.ctype)
 
         return result
diff --git a/generator/proto/nanopb.proto b/generator/proto/nanopb.proto
index 64f1cc4..83ad9c4 100644
--- a/generator/proto/nanopb.proto
+++ b/generator/proto/nanopb.proto
@@ -103,6 +103,10 @@
   // Generate repeated field with fixed count
   optional bool fixed_count = 16 [default = false];
 
+  // Generate message-level callback that is called before decoding submessages.
+  // This can be used to set callback fields for submsgs inside oneofs.
+  optional bool submsg_callback = 22 [default = false];
+
   // Mangle long names
   optional TypenameMangling mangle_names = 17 [default = M_NONE];
 
diff --git a/pb.h b/pb.h
index 32046a2..91c1e79 100644
--- a/pb.h
+++ b/pb.h
@@ -34,10 +34,6 @@
    or to save some code space. */
 /* #define PB_WITHOUT_64BIT 1 */
 
-/* Switch back to the old-style callback function signature.
- * This was the default until nanopb-0.2.1. */
-/* #define PB_OLD_CALLBACK_STYLE 1 */
-
 /* Set the fieldinfo width for all messages using automatic width
  * selection. Valid values are 2, 4 and 8. Usually even if you need
  * to change the width manually for some reason, it is preferrable
@@ -205,18 +201,23 @@
  * submsg_fields is pointer to field descriptions */
 #define PB_LTYPE_SUBMESSAGE 0x08U
 
+/* Submessage with pre-decoding callback
+ * The pre-decoding callback is stored as pb_callback_t right before pSize.
+ * submsg_fields is pointer to field descriptions */
+#define PB_LTYPE_SUBMSG_W_CB 0x09U
+
 /* Extension pseudo-field
  * The field contains a pointer to pb_extension_t */
-#define PB_LTYPE_EXTENSION 0x09U
+#define PB_LTYPE_EXTENSION 0x0AU
 
 /* Byte array with inline, pre-allocated byffer.
  * data_size is the length of the inline, allocated buffer.
  * This differs from PB_LTYPE_BYTES by defining the element as
  * pb_byte_t[data_size] rather than pb_bytes_array_t. */
-#define PB_LTYPE_FIXED_LENGTH_BYTES 0x0AU
+#define PB_LTYPE_FIXED_LENGTH_BYTES 0x0BU
 
 /* Number of declared LTYPES */
-#define PB_LTYPES_COUNT 0x0BU
+#define PB_LTYPES_COUNT 0x0CU
 #define PB_LTYPE_MASK 0x0FU
 
 /**** Field repetition rules ****/
@@ -239,6 +240,8 @@
 #define PB_ATYPE(x) ((x) & PB_ATYPE_MASK)
 #define PB_HTYPE(x) ((x) & PB_HTYPE_MASK)
 #define PB_LTYPE(x) ((x) & PB_LTYPE_MASK)
+#define PB_LTYPE_IS_SUBMSG(x) (PB_LTYPE(x) == PB_LTYPE_SUBMESSAGE || \
+                               PB_LTYPE(x) == PB_LTYPE_SUBMSG_W_CB)
 
 /* Data type used for storing sizes of struct fields
  * and array counts.
@@ -347,19 +350,13 @@
  */
 typedef struct pb_callback_s pb_callback_t;
 struct pb_callback_s {
-#ifdef PB_OLD_CALLBACK_STYLE
-    /* Deprecated since nanopb-0.2.1 */
-    union {
-        bool (*decode)(pb_istream_t *stream, const pb_field_t *field, void *arg);
-        bool (*encode)(pb_ostream_t *stream, const pb_field_t *field, const void *arg);
-    } funcs;
-#else
-    /* New function signature, which allows modifying arg contents in callback. */
+    /* Callback functions receive a pointer to the arg field.
+     * You can access the value of the field as *arg, and modify it if needed.
+     */
     union {
         bool (*decode)(pb_istream_t *stream, const pb_field_t *field, void **arg);
         bool (*encode)(pb_ostream_t *stream, const pb_field_t *field, void * const *arg);
     } funcs;
-#endif    
     
     /* Free arg for use by callback */
     void *arg;
@@ -535,7 +532,7 @@
 
 #define PB_SIZE_OFFSET_STATIC(htype, structname, fieldname) PB_SIZE_OFFSET_ ## htype(structname, fieldname)
 #define PB_SIZE_OFFSET_POINTER(htype, structname, fieldname) PB_SIZE_OFFSET_PTR_ ## htype(structname, fieldname)
-#define PB_SIZE_OFFSET_CALLBACK(htype, structname, fieldname) 0
+#define PB_SIZE_OFFSET_CALLBACK(htype, structname, fieldname) PB_SIZE_OFFSET_CB_ ## htype(structname, fieldname)
 #define PB_SIZE_OFFSET_REQUIRED(structname, fieldname) 0
 #define PB_SIZE_OFFSET_SINGULAR(structname, fieldname) 0
 #define PB_SIZE_OFFSET_ONEOF(structname, fieldname) PB_SIZE_OFFSET_ONEOF2(structname, PB_ONEOF_NAME(FULL, fieldname), PB_ONEOF_NAME(UNION, fieldname))
@@ -550,6 +547,12 @@
 #define PB_SIZE_OFFSET_PTR_OPTIONAL(structname, fieldname) 0
 #define PB_SIZE_OFFSET_PTR_REPEATED(structname, fieldname) PB_SIZE_OFFSET_REPEATED(structname, fieldname)
 #define PB_SIZE_OFFSET_PTR_FIXARRAY(structname, fieldname) 0
+#define PB_SIZE_OFFSET_CB_REQUIRED(structname, fieldname) 0
+#define PB_SIZE_OFFSET_CB_SINGULAR(structname, fieldname) 0
+#define PB_SIZE_OFFSET_CB_ONEOF(structname, fieldname) PB_SIZE_OFFSET_ONEOF(structname, fieldname)
+#define PB_SIZE_OFFSET_CB_OPTIONAL(structname, fieldname) 0
+#define PB_SIZE_OFFSET_CB_REPEATED(structname, fieldname) 0
+#define PB_SIZE_OFFSET_CB_FIXARRAY(structname, fieldname) 0
 
 #define PB_ARRAY_SIZE_STATIC(htype, structname, fieldname) PB_ARRAY_SIZE_ ## htype(structname, fieldname)
 #define PB_ARRAY_SIZE_POINTER(htype, structname, fieldname) 1
@@ -563,7 +566,7 @@
 
 #define PB_DATA_SIZE_STATIC(htype, structname, fieldname) PB_DATA_SIZE_ ## htype(structname, fieldname)
 #define PB_DATA_SIZE_POINTER(htype, structname, fieldname) PB_DATA_SIZE_PTR_ ## htype(structname, fieldname)
-#define PB_DATA_SIZE_CALLBACK(htype, structname, fieldname) pb_membersize(structname, fieldname)
+#define PB_DATA_SIZE_CALLBACK(htype, structname, fieldname) PB_DATA_SIZE_CB_ ## htype(structname, fieldname)
 #define PB_DATA_SIZE_REQUIRED(structname, fieldname) pb_membersize(structname, fieldname)
 #define PB_DATA_SIZE_SINGULAR(structname, fieldname) pb_membersize(structname, fieldname)
 #define PB_DATA_SIZE_OPTIONAL(structname, fieldname) pb_membersize(structname, fieldname)
@@ -576,6 +579,12 @@
 #define PB_DATA_SIZE_PTR_ONEOF(structname, fieldname) pb_membersize(structname, PB_ONEOF_NAME(FULL, fieldname)[0])
 #define PB_DATA_SIZE_PTR_REPEATED(structname, fieldname) pb_membersize(structname, fieldname[0])
 #define PB_DATA_SIZE_PTR_FIXARRAY(structname, fieldname) pb_membersize(structname, fieldname[0])
+#define PB_DATA_SIZE_CB_REQUIRED(structname, fieldname) pb_membersize(structname, fieldname)
+#define PB_DATA_SIZE_CB_SINGULAR(structname, fieldname) pb_membersize(structname, fieldname)
+#define PB_DATA_SIZE_CB_OPTIONAL(structname, fieldname) pb_membersize(structname, fieldname)
+#define PB_DATA_SIZE_CB_ONEOF(structname, fieldname) pb_membersize(structname, PB_ONEOF_NAME(FULL, fieldname))
+#define PB_DATA_SIZE_CB_REPEATED(structname, fieldname) pb_membersize(structname, fieldname)
+#define PB_DATA_SIZE_CB_FIXARRAY(structname, fieldname) pb_membersize(structname, fieldname)
 
 #define PB_ONEOF_NAME(type, tuple) PB_EXPAND(PB_ONEOF_NAME_ ## type tuple)
 #define PB_ONEOF_NAME_UNION(unionname,membername,fullname) unionname
@@ -604,6 +613,7 @@
 #define PB_SUBMSG_INFO_INT32(t)
 #define PB_SUBMSG_INFO_INT64(t)
 #define PB_SUBMSG_INFO_MESSAGE(t)  PB_SUBMSG_DESCRIPTOR(t)
+#define PB_SUBMSG_INFO_MSG_W_CB(t) PB_SUBMSG_DESCRIPTOR(t)
 #define PB_SUBMSG_INFO_SFIXED32(t)
 #define PB_SUBMSG_INFO_SFIXED64(t)
 #define PB_SUBMSG_INFO_SINT32(t)
@@ -746,6 +756,7 @@
 #define PB_LTYPE_MAP_INT32              PB_LTYPE_VARINT
 #define PB_LTYPE_MAP_INT64              PB_LTYPE_VARINT
 #define PB_LTYPE_MAP_MESSAGE            PB_LTYPE_SUBMESSAGE
+#define PB_LTYPE_MAP_MSG_W_CB           PB_LTYPE_SUBMSG_W_CB
 #define PB_LTYPE_MAP_SFIXED32           PB_LTYPE_FIXED32
 #define PB_LTYPE_MAP_SFIXED64           PB_LTYPE_FIXED64
 #define PB_LTYPE_MAP_SINT32             PB_LTYPE_SVARINT
diff --git a/pb_common.c b/pb_common.c
index 9e9527d..00bc7d6 100644
--- a/pb_common.c
+++ b/pb_common.c
@@ -94,7 +94,7 @@
         iter->pData = iter->pField;
     }
 
-    if (PB_LTYPE(iter->type) == PB_LTYPE_SUBMESSAGE)
+    if (PB_LTYPE_IS_SUBMSG(iter->type))
     {
         iter->submsg_desc = iter->descriptor->submsg_info[iter->submessage_index];
     }
@@ -137,7 +137,7 @@
             iter->required_field_index++;
         }
 
-        if (PB_LTYPE(prev_type) == PB_LTYPE_SUBMESSAGE)
+        if (PB_LTYPE_IS_SUBMSG(prev_type))
         {
             iter->submessage_index++;
         }
@@ -231,20 +231,12 @@
         {
             if (istream != NULL && pCallback->funcs.decode != NULL)
             {
-#ifdef PB_OLD_CALLBACK_STYLE
-                return pCallback->funcs.decode(istream, field, pCallback->arg);
-#else
                 return pCallback->funcs.decode(istream, field, &pCallback->arg);
-#endif
             }
 
             if (ostream != NULL && pCallback->funcs.encode != NULL)
             {
-#ifdef PB_OLD_CALLBACK_STYLE
-                return pCallback->funcs.encode(ostream, field, pCallback->arg);
-#else
                 return pCallback->funcs.encode(ostream, field, &pCallback->arg);
-#endif
             }
         }
     }
diff --git a/pb_decode.c b/pb_decode.c
index 68351b5..892e688 100644
--- a/pb_decode.c
+++ b/pb_decode.c
@@ -410,6 +410,7 @@
             return pb_dec_string(stream, field);
 
         case PB_LTYPE_SUBMESSAGE:
+        case PB_LTYPE_SUBMSG_W_CB:
             return pb_dec_submessage(stream, field);
 
         case PB_LTYPE_FIXED_LENGTH_BYTES:
@@ -477,9 +478,14 @@
 
         case PB_HTYPE_ONEOF:
             *(pb_size_t*)field->pSize = field->tag;
-            if (PB_LTYPE(field->type) == PB_LTYPE_SUBMESSAGE)
+            if (PB_LTYPE_IS_SUBMSG(field->type))
             {
                 /* We memset to zero so that any callbacks are set to NULL.
+                 * This is because the callbacks might otherwise have values
+                 * from some other union field.
+                 * If callbacks are needed inside oneof field, use .proto
+                 * option submsg_callback to have a separate callback function
+                 * that can set the fields before submessage is decoded.
                  * pb_dec_submessage() will set any default values. */
                 memset(field->pData, 0, (size_t)field->data_size);
             }
@@ -538,7 +544,7 @@
     {
         *(void**)pItem = NULL;
     }
-    else if (PB_LTYPE(field->type) == PB_LTYPE_SUBMESSAGE)
+    else if (PB_LTYPE_IS_SUBMSG(field->type))
     {
         /* We memset to zero so that any callbacks are set to NULL.
          * Then set any default values. */
@@ -565,8 +571,7 @@
         case PB_HTYPE_REQUIRED:
         case PB_HTYPE_OPTIONAL:
         case PB_HTYPE_ONEOF:
-            if (PB_LTYPE(field->type) == PB_LTYPE_SUBMESSAGE &&
-                *(void**)field->pField != NULL)
+            if (PB_LTYPE_IS_SUBMSG(field->type) && *(void**)field->pField != NULL)
             {
                 /* Duplicate field, have to release the old allocation first. */
                 /* FIXME: Does this work correctly for oneofs? */
@@ -843,7 +848,7 @@
 
         if (init_data)
         {
-            if (PB_LTYPE(field->type) == PB_LTYPE_SUBMESSAGE)
+            if (PB_LTYPE_IS_SUBMSG(field->type))
             {
                 /* Initialize submessage to defaults */
                 pb_field_iter_t submsg_iter;
@@ -1180,7 +1185,7 @@
             ext = ext->next;
         }
     }
-    else if (PB_LTYPE(type) == PB_LTYPE_SUBMESSAGE && PB_ATYPE(type) != PB_ATYPE_CALLBACK)
+    else if (PB_LTYPE_IS_SUBMSG(type) && PB_ATYPE(type) != PB_ATYPE_CALLBACK)
     {
         /* Release fields in submessage or submsg array */
         pb_size_t count = 1;
@@ -1570,7 +1575,7 @@
 
 static bool checkreturn pb_dec_submessage(pb_istream_t *stream, const pb_field_iter_t *field)
 {
-    bool status;
+    bool status = true;
     pb_istream_t substream;
 
     if (!pb_make_string_substream(stream, &substream))
@@ -1583,9 +1588,33 @@
      * submessages have already been initialized in the top-level pb_decode. */
     if (PB_HTYPE(field->type) == PB_HTYPE_REPEATED ||
         PB_HTYPE(field->type) == PB_HTYPE_ONEOF)
-        status = pb_decode(&substream, field->submsg_desc, field->pData);
-    else
+    {
+        pb_field_iter_t submsg_iter;
+        if (pb_field_iter_begin(&submsg_iter, field->submsg_desc, field->pData))
+        {
+            if (!pb_message_set_to_defaults(&submsg_iter))
+                PB_RETURN_ERROR(stream, "failed to set defaults");
+        }
+    }
+
+    /* Submessages can have a separate message-level callback that is called
+     * before decoding the message. Typically it is used to set callback fields
+     * inside oneofs. */
+    if (PB_LTYPE(field->type) == PB_LTYPE_SUBMSG_W_CB && field->pSize != NULL)
+    {
+        /* Message callback is stored right before pSize. */
+        pb_callback_t *callback = (pb_callback_t*)field->pSize - 1;
+        if (callback->funcs.decode)
+        {
+            status = callback->funcs.decode(&substream, field, &callback->arg);
+        }
+    }
+
+    /* Now decode the submessage contents */
+    if (status)
+    {
         status = pb_decode_noinit(&substream, field->submsg_desc, field->pData);
+    }
     
     if (!pb_close_string_substream(stream, &substream))
         return false;
diff --git a/pb_encode.c b/pb_encode.c
index 093173b..4e40505 100644
--- a/pb_encode.c
+++ b/pb_encode.c
@@ -259,7 +259,8 @@
         }
         else if (PB_HTYPE(type) == PB_HTYPE_OPTIONAL && field->pSize != NULL)
         {
-            /* Proto2 optional fields inside proto3 submessage */
+            /* Proto2 optional fields inside proto3 message, or proto3
+             * submessage fields. */
             return safe_read_bool(field->pSize) == false;
         }
 
@@ -280,12 +281,13 @@
              * it anyway. */
             return field->data_size == 0;
         }
-        else if (PB_LTYPE(type) == PB_LTYPE_SUBMESSAGE)
+        else if (PB_LTYPE_IS_SUBMSG(type))
         {
             /* Check all fields in the submessage to find if any of them
              * are non-zero. The comparison cannot be done byte-per-byte
              * because the C struct may contain padding bytes that must
-             * be skipped.
+             * be skipped. Note that usually proto3 submessages have
+             * a separate has_field that is checked earlier in this if.
              */
             pb_field_iter_t iter;
             if (pb_field_iter_begin(&iter, field->submsg_desc, field->pData))
@@ -358,6 +360,7 @@
             return pb_enc_string(stream, field);
 
         case PB_LTYPE_SUBMESSAGE:
+        case PB_LTYPE_SUBMSG_W_CB:
             return pb_enc_submessage(stream, field);
 
         case PB_LTYPE_FIXED_LENGTH_BYTES:
@@ -383,6 +386,43 @@
 /* Encode a single field of any callback, pointer or static type. */
 static bool checkreturn encode_field(pb_ostream_t *stream, pb_field_iter_t *field)
 {
+    /* Check field presence */
+    if (PB_HTYPE(field->type) == PB_HTYPE_ONEOF)
+    {
+        if (*(const pb_size_t*)field->pSize != field->tag)
+        {
+            /* Different type oneof field */
+            return true;
+        }
+    }
+    else if (PB_HTYPE(field->type) == PB_HTYPE_OPTIONAL)
+    {
+        if (field->pSize)
+        {
+            if (safe_read_bool(field->pSize) == false)
+            {
+                /* Missing optional field */
+                return true;
+            }
+        }
+        else if (PB_ATYPE(field->type) == PB_ATYPE_STATIC)
+        {
+            /* Proto3 singular field */
+            if (pb_check_proto3_default_value(field))
+                return true;
+        }
+    }
+
+    if (!field->pData)
+    {
+        if (PB_HTYPE(field->type) == PB_HTYPE_REQUIRED)
+            PB_RETURN_ERROR(stream, "missing required field");
+
+        /* Pointer field set to NULL */
+        return true;
+    }
+
+    /* Then encode field contents */
     if (PB_ATYPE(field->type) == PB_ATYPE_CALLBACK)
     {
         return encode_callback_field(stream, field);
@@ -391,45 +431,9 @@
     {
         return encode_array(stream, field);
     }
-    else if (PB_HTYPE(field->type) == PB_HTYPE_REQUIRED)
-    {
-        if (!field->pData)
-            PB_RETURN_ERROR(stream, "missing required field");
-
-        return encode_basic_field(stream, field);
-    }
-    else if (PB_HTYPE(field->type) == PB_HTYPE_OPTIONAL)
-    {
-        if (PB_ATYPE(field->type) == PB_ATYPE_STATIC)
-        {
-            if (!field->pSize)
-            {
-                /* Proto3 singular field */
-                if (pb_check_proto3_default_value(field))
-                    return true;
-            }
-            else if (safe_read_bool(field->pSize) == false)
-            {
-                /* Missing optional field */
-                return true;
-            }
-        }
-
-        return encode_basic_field(stream, field);
-    }
-    else if (PB_HTYPE(field->type) == PB_HTYPE_ONEOF)
-    {
-        if (*(const pb_size_t*)field->pSize != field->tag)
-        {
-            /* Different type oneof field */
-            return true;
-        }
-
-        return encode_basic_field(stream, field);
-    }
     else
     {
-        PB_RETURN_ERROR(stream, "invalid field type");
+        return encode_basic_field(stream, field);
     }
 }
 
@@ -667,6 +671,7 @@
         case PB_LTYPE_BYTES:
         case PB_LTYPE_STRING:
         case PB_LTYPE_SUBMESSAGE:
+        case PB_LTYPE_SUBMSG_W_CB:
         case PB_LTYPE_FIXED_LENGTH_BYTES:
             wiretype = PB_WT_STRING;
             break;
@@ -891,6 +896,17 @@
 {
     if (field->submsg_desc == NULL)
         PB_RETURN_ERROR(stream, "invalid field descriptor");
+
+    if (PB_LTYPE(field->type) == PB_LTYPE_SUBMSG_W_CB && field->pSize != NULL)
+    {
+        /* Message callback is stored right before pSize. */
+        pb_callback_t *callback = (pb_callback_t*)field->pSize - 1;
+        if (callback->funcs.encode)
+        {
+            if (!callback->funcs.encode(stream, field, &callback->arg))
+                return false;
+        }
+    }
     
     return pb_encode_submessage(stream, field->submsg_desc, field->pData);
 }
diff --git a/tests/oneof_callback/SConscript b/tests/oneof_callback/SConscript
new file mode 100644
index 0000000..03e9658
--- /dev/null
+++ b/tests/oneof_callback/SConscript
@@ -0,0 +1,22 @@
+# Test decoder callback support inside oneofs.
+
+Import('env')
+
+env.NanopbProto('oneof')
+
+enc = env.Program(['encode_oneof.c',
+                'oneof.pb.c',
+                '$COMMON/pb_encode.o',
+                '$COMMON/pb_common.o'])
+
+dec = env.Program(['decode_oneof.c',
+                'oneof.pb.c',
+                '$COMMON/pb_decode.o',
+                '$COMMON/pb_common.o'])
+
+for i in range(1,7):
+    # Encode message, then decode with protoc and test program and compare.
+    e = env.RunTest("message%d.pb" % i, enc, ARGS = [str(i)])
+    d1 = env.Decode([e, "oneof.proto"], MESSAGE = "OneOfMessage")
+    d2 = env.RunTest("message%d.txt" % i, [dec, e])
+    env.Compare([d1, d2])
diff --git a/tests/oneof_callback/decode_oneof.c b/tests/oneof_callback/decode_oneof.c
new file mode 100644
index 0000000..3b38f15
--- /dev/null
+++ b/tests/oneof_callback/decode_oneof.c
@@ -0,0 +1,174 @@
+/* Decode a message using callbacks inside oneof fields */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <pb_decode.h>
+#include <assert.h>
+#include "oneof.pb.h"
+#include "test_helpers.h"
+#include "unittests.h"
+
+/* This is a nanopb-0.4 style global callback, that is referred by function name
+ * and does not have to be bound separately to the message. It also allows defining
+ * a custom data type for the field in the structure.
+ */
+bool SubMsg3_callback(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field)
+{
+    if (istream && field->tag == SubMsg3_strvalue_tag)
+    {
+        /* We could e.g. malloc some memory and assign it to our custom datatype
+         * in the message structure here, accessible by field->pData. But in
+         * this example we just print the string directly.
+         */
+
+        uint8_t buffer[64];
+        int strlen = istream->bytes_left;
+
+        if (strlen > sizeof(buffer) - 1)
+            return false;
+
+        buffer[strlen] = '\0';
+
+        if (!pb_read(istream, buffer, strlen))
+            return false;
+
+        printf("  strvalue: \"%s\"\n", buffer);
+    }
+
+    return true;
+}
+
+/* The two callbacks below are traditional callbacks that use function pointers
+ * defined in pb_callback_t.
+ */
+bool print_int32(pb_istream_t *stream, const pb_field_t *field, void **arg)
+{
+    uint64_t value;
+    if (!pb_decode_varint(stream, &value))
+        return false;
+
+    printf((char*)*arg, (long)value);
+    return true;
+}
+
+bool print_string(pb_istream_t *stream, const pb_field_t *field, void **arg)
+{
+    uint8_t buffer[64];
+    int strlen = stream->bytes_left;
+
+    if (strlen > sizeof(buffer) - 1)
+        return false;
+
+    buffer[strlen] = '\0';
+
+    if (!pb_read(stream, buffer, strlen))
+        return false;
+
+    /* Print the string, in format comparable with protoc --decode.
+     * Format comes from the arg defined in main().
+     */
+    printf((char*)*arg, buffer);
+    return true;
+}
+
+/* The callback below is a message-level callback which is called before each
+ * submessage is encoded. It is used to set the pb_callback_t callbacks inside
+ * the submessage. The reason we need this is that different submessages share
+ * storage inside oneof union, and before we know the message type we can't set
+ * the callbacks without overwriting each other.
+ */
+bool msg_callback(pb_istream_t *stream, const pb_field_t *field, void **arg)
+{
+    /* Print the prefix field before the submessages.
+     * This also demonstrates how to access the top level message fields
+     * from callbacks.
+     */
+    OneOfMessage *topmsg = field->message;
+    printf("prefix: %d\n", topmsg->prefix);
+
+    if (field->tag == OneOfMessage_submsg1_tag)
+    {
+        SubMsg1 *msg = field->pData;
+        printf("submsg1 {\n");
+        msg->array.funcs.decode = print_int32;
+        msg->array.arg = "  array: %d\n";
+    }
+    else if (field->tag == OneOfMessage_submsg2_tag)
+    {
+        SubMsg2 *msg = field->pData;
+        printf("submsg2 {\n");
+        msg->strvalue.funcs.decode = print_string;
+        msg->strvalue.arg = "  strvalue: \"%s\"\n";
+    }
+    else if (field->tag == OneOfMessage_submsg3_tag)
+    {
+        /* Because SubMsg3 callback is bound by function name, we do not
+         * need to initialize anything here. But we just print a string
+         * to get protoc-equivalent formatted output from the testcase.
+         */
+        printf("submsg3 {\n");
+    }
+
+    /* Once we return true, pb_dec_submessage() will go on to decode the
+     * submessage contents. But if we want, we can also decode it ourselves
+     * above and leave stream->bytes_left at 0 value, inhibiting automatic
+     * decoding.
+     */
+    return true;
+}
+
+int main(int argc, char **argv)
+{
+    uint8_t buffer[256];
+    OneOfMessage msg = OneOfMessage_init_zero;
+    pb_istream_t stream;
+    size_t count;
+
+    SET_BINARY_MODE(stdin);
+    count = fread(buffer, 1, sizeof(buffer), stdin);
+
+    if (!feof(stdin))
+    {
+        fprintf(stderr, "Message does not fit in buffer\n");
+        return 1;
+    }
+
+    /* Set up the cb_values callback, which will in turn set up the callbacks
+     * for each oneof field once the field tag is known. */
+    msg.cb_values.funcs.decode = msg_callback;
+    stream = pb_istream_from_buffer(buffer, count);
+    if (!pb_decode(&stream, OneOfMessage_fields, &msg))
+    {
+        fprintf(stderr, "Decoding failed: %s\n", PB_GET_ERROR(&stream));
+        return 1;
+    }
+
+    /* This is just printing for the test case logic */
+    if (msg.which_values == OneOfMessage_intvalue_tag)
+    {
+        printf("prefix: %d\n", msg.prefix);
+        printf("intvalue: %d\n", msg.values.intvalue);
+    }
+    else if (msg.which_values == OneOfMessage_strvalue_tag)
+    {
+        printf("prefix: %d\n", msg.prefix);
+        printf("strvalue: \"%s\"\n", msg.values.strvalue);
+    }
+    else if (msg.which_values == OneOfMessage_submsg3_tag &&
+             msg.values.submsg3.which_values == SubMsg3_intvalue_tag)
+    {
+        printf("  intvalue: %d\n", msg.values.submsg3.values.intvalue);
+        printf("}\n");
+    }
+    else
+    {
+        printf("}\n");
+    }
+    printf("suffix: %d\n", msg.suffix);
+
+    assert(msg.prefix == 123);
+    assert(msg.suffix == 321);
+
+    return 0;
+}
diff --git a/tests/oneof_callback/encode_oneof.c b/tests/oneof_callback/encode_oneof.c
new file mode 100644
index 0000000..ad7a54c
--- /dev/null
+++ b/tests/oneof_callback/encode_oneof.c
@@ -0,0 +1,126 @@
+/* Encode a message using callbacks inside oneof fields.
+ * For encoding, callbacks inside oneofs require nothing special
+ * so this is just normal callback usage.
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <pb_encode.h>
+#include "oneof.pb.h"
+#include "test_helpers.h"
+
+/* This is a nanopb-0.4 style global callback, that is referred by function name
+ * and does not have to be bound separately to the message. It also allows defining
+ * a custom data type for the field in the structure.
+ */
+bool SubMsg3_callback(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field)
+{
+    if (ostream && field->tag == SubMsg3_strvalue_tag)
+    {
+         /* Our custom data type is char* */
+        const char *str = *(const char**)field->pData;
+
+        if (!pb_encode_tag_for_field(ostream, field))
+            return false;
+
+        return pb_encode_string(ostream, (const uint8_t*)str, strlen(str));
+    }
+
+    return true;
+}
+
+/* The two callbacks below are traditional callbacks that use function pointers
+ * defined in pb_callback_t.
+ */
+bool encode_int32_array(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)
+{
+    int i;
+    for (i = 0; i < 15; i++)
+    {
+        if (!pb_encode_tag_for_field(stream, field))
+            return false;
+
+        if (!pb_encode_varint(stream, i))
+            return false;
+    }
+    return true;
+}
+
+bool encode_string(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)
+{
+    const char *str = "mystring";
+
+    if (!pb_encode_tag_for_field(stream, field))
+        return false;
+
+    return pb_encode_string(stream, (const uint8_t*)str, strlen(str));
+}
+
+int main(int argc, char **argv)
+{
+    uint8_t buffer[256];
+    OneOfMessage msg = OneOfMessage_init_zero;
+    pb_ostream_t stream;
+    int option;
+
+    if (argc != 2)
+    {
+        fprintf(stderr, "Usage: encode_oneof [number]\n");
+        return 1;
+    }
+    option = atoi(argv[1]);
+
+    /* Prefix and suffix are used to test that the union does not disturb
+     * other fields in the same message. */
+    msg.prefix = 123;
+
+    /* We encode one of the 'values' fields based on command line argument */
+    if (option == 1)
+    {
+        msg.which_values = OneOfMessage_intvalue_tag;
+        msg.values.intvalue = 999;
+    }
+    else if (option == 2)
+    {
+        msg.which_values = OneOfMessage_strvalue_tag;
+        strcpy(msg.values.strvalue, "abcd");
+    }
+    else if (option == 3)
+    {
+        msg.which_values = OneOfMessage_submsg1_tag;
+        msg.values.submsg1.array.funcs.encode = encode_int32_array;
+    }
+    else if (option == 4)
+    {
+        msg.which_values = OneOfMessage_submsg2_tag;
+        msg.values.submsg2.strvalue.funcs.encode = encode_string;
+    }
+    else if (option == 5)
+    {
+        msg.which_values = OneOfMessage_submsg3_tag;
+        msg.values.submsg3.which_values = SubMsg3_intvalue_tag;
+        msg.values.submsg3.values.intvalue = 1234;
+    }
+    else if (option == 6)
+    {
+        msg.which_values = OneOfMessage_submsg3_tag;
+        msg.values.submsg3.which_values = SubMsg3_strvalue_tag;
+        msg.values.submsg3.values.strvalue = "efgh";
+    }
+
+    msg.suffix = 321;
+
+    stream = pb_ostream_from_buffer(buffer, sizeof(buffer));
+
+    if (pb_encode(&stream, OneOfMessage_fields, &msg))
+    {
+        SET_BINARY_MODE(stdout);
+        fwrite(buffer, 1, stream.bytes_written, stdout);
+        return 0;
+    }
+    else
+    {
+        fprintf(stderr, "Encoding failed: %s\n", PB_GET_ERROR(&stream));
+        return 1;
+    }
+}
diff --git a/tests/oneof_callback/oneof.proto b/tests/oneof_callback/oneof.proto
new file mode 100644
index 0000000..41b70b6
--- /dev/null
+++ b/tests/oneof_callback/oneof.proto
@@ -0,0 +1,41 @@
+syntax = "proto3";
+
+import 'nanopb.proto';
+
+// Repeated callback inside submessage inside oneof
+message SubMsg1
+{
+    repeated int32 array = 1;
+}
+
+// String callback inside submessage inside oneof
+message SubMsg2
+{
+    string strvalue = 1;
+}
+
+// String callback directly inside oneof
+message SubMsg3
+{
+    oneof values
+    {
+        int32 intvalue = 1;
+        string strvalue = 2 [(nanopb).callback_datatype = "const char*"];
+    }
+}
+
+message OneOfMessage
+{
+    option (nanopb_msgopt).submsg_callback = true;
+
+    int32 prefix = 1;
+    oneof values
+    {
+        int32 intvalue = 5;
+        string strvalue = 6 [(nanopb).max_size = 8];
+        SubMsg1 submsg1 = 7;
+        SubMsg2 submsg2 = 8;
+        SubMsg3 submsg3 = 9;
+    }
+    int32 suffix = 99;
+}
diff --git a/tests/regression/issue_342/SConscript b/tests/regression/issue_342/SConscript
index e15d022..5178290 100644
--- a/tests/regression/issue_342/SConscript
+++ b/tests/regression/issue_342/SConscript
@@ -3,19 +3,9 @@
 
 Import("env")
 
-# Define the compilation options
-opts = env.Clone()
-opts.Append(CPPDEFINES = {'PB_OLD_CALLBACK_STYLE': 1})
+env.NanopbProto("extensions")
+testprog = env.Program(["test_extensions.c", "extensions.pb.c",
+                '$COMMON/pb_decode.o', '$COMMON/pb_encode.o', '$COMMON/pb_common.o'])
 
-# Build new version of core
-strict = opts.Clone()
-strict.Append(CFLAGS = strict['CORECFLAGS'])
-strict.Object("pb_decode_oldcallback.o", "$NANOPB/pb_decode.c")
-strict.Object("pb_encode_oldcallback.o", "$NANOPB/pb_encode.c")
-strict.Object("pb_common_oldcallback.o", "$NANOPB/pb_common.c")
-
-opts.NanopbProto("extensions")
-testprog = opts.Program(["test_extensions.c", "extensions.pb.c", "pb_encode_oldcallback.o", "pb_decode_oldcallback.o", "pb_common_oldcallback.o"])
-
-opts.RunTest(testprog)
+env.RunTest(testprog)
 
diff --git a/tests/regression/issue_342/test_extensions.c b/tests/regression/issue_342/test_extensions.c
index 1a48855..1fb0c30 100644
--- a/tests/regression/issue_342/test_extensions.c
+++ b/tests/regression/issue_342/test_extensions.c
@@ -6,7 +6,7 @@
 #include "extensions.pb.h"
 #include "unittests.h"
 
-static bool write_string(pb_ostream_t *stream, const pb_field_t *field, const void *arg)
+static bool write_string(pb_ostream_t *stream, const pb_field_t *field, void * const *arg)
 {
     return pb_encode_tag_for_field(stream, field) &&
            pb_encode_string(stream, (const void*)"abc", 3);
@@ -41,7 +41,7 @@
         BaseMessage msg = {0};
         
         ext.type = &string_extension;
-        /* Note: ext.dest remains null to trigger buf #342 */
+        /* Note: ext.dest remains null to trigger bug #342 */
         msg.extensions = &ext;
     
         TEST(pb_decode(&stream, BaseMessage_fields, &msg));