blob: 6b82fd51162ea23c896e4a055594debe1a60bbcd [file] [log] [blame]
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
"""Unittest for thread safe"""
import sys
import threading
import time
import unittest
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import message_factory
from google.protobuf.internal import api_implementation
from google.protobuf import unittest_pb2
class ThreadSafeTest(unittest.TestCase):
  """Checks that concurrent parsing yields correct messages."""

  def setUp(self):
    # Count of parses (across all threads) that reproduced the source message.
    self.success = 0

  def testFieldDecodersDataRace(self):
    """Parses the same bytes on two threads while decoders are uncached.

    Before every round the ``_decoders`` attribute, if present, is removed
    from the ``optional_int32`` field descriptor. Both threads then parse the
    serialized message, and every single parse must compare equal to the
    original message.
    """
    expected = unittest_pb2.TestAllTypes(optional_int32=1)
    payload = expected.SerializeToString()
    counter_lock = threading.Lock()

    def ParseMessage():
      result = unittest_pb2.TestAllTypes()
      # Brief pause before parsing, presumably to make the two threads more
      # likely to overlap on the parse.
      time.sleep(0.005)
      result.ParseFromString(payload)
      with counter_lock:
        if expected == result:
          self.success += 1

    field_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32'
    ]
    rounds = 1000
    for _round in range(rounds):
      # Only the first parse of a field may race on decoder setup, so drop
      # the cached decoders to reopen that window on each round.
      if hasattr(field_descriptor, '_decoders'):
        delattr(field_descriptor, '_decoders')
      workers = [threading.Thread(target=ParseMessage) for _ in range(2)]
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()
    self.assertEqual(rounds * 2, self.success)
class FreeThreadingTest(unittest.TestCase):
  """Exercises protobuf APIs from many threads at the same time."""

  def RunThreads(self, thread_size, func):
    """Runs ``func`` once on each of ``thread_size`` threads and joins them.

    ``threading.Thread`` does not propagate exceptions raised by its target;
    they are only reported through ``threading.excepthook``. To make a
    failing worker fail the test, exceptions are collected here and the
    first one recorded is re-raised in the calling thread after every worker
    has been joined.

    Args:
      thread_size: Number of threads to start.
      func: Zero-argument callable run once per thread.

    Raises:
      Exception: The first exception recorded from any worker thread.
    """
    errors = []
    errors_lock = threading.Lock()

    def Worker():
      try:
        func()
      except Exception as e:  # pylint: disable=broad-except
        # Recorded and re-raised on the main thread below.
        with errors_lock:
          errors.append(e)

    threads = [threading.Thread(target=Worker) for _ in range(thread_size)]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()
    if errors:
      raise errors[0]

  def testDoNothing(self):
    """Sanity check that RunThreads starts and joins no-op threads."""
    thread_size = 10

    def DoNothing():
      return

    self.RunThreads(thread_size, DoNothing)

  @unittest.skipIf(
      api_implementation.Type() != 'cpp',
      'Only cpp supports free threading for now',
  )
  def testDescriptorPoolMap(self):
    """Each thread builds its own DescriptorPool and message class.

    Every thread adds a file ``foo`` declaring ``SomeMessage`` to a fresh
    pool, generates the message class, and sets a field on an instance.
    """
    thread_size = 20
    self.success_count = 0
    lock = threading.Lock()

    def CreatePool():
      def DoCreate():
        pool = descriptor_pool.DescriptorPool()
        file_proto = descriptor_pb2.FileDescriptorProto(name='foo')
        message_proto = file_proto.message_type.add(name='SomeMessage')
        message_proto.field.add(
            name='int_field',
            number=1,
            type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
            label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
        )
        pool.Add(file_proto)
        desc = pool.FindMessageTypeByName('SomeMessage')
        msg = message_factory.GetMessageClass(desc)()
        msg.int_field = 1

      DoCreate()
      with lock:
        self.success_count += 1

    self.RunThreads(thread_size, CreatePool)
    self.assertEqual(thread_size, self.success_count)
if __name__ == '__main__':
  # Run the tests in this module when it is executed directly as a script.
  unittest.main()