| # Protocol Buffers - Google's data interchange format |
| # Copyright 2008 Google Inc. All rights reserved. |
| # |
| # Use of this source code is governed by a BSD-style |
| # license that can be found in the LICENSE file or at |
| # https://developers.google.com/open-source/licenses/bsd |
| |
"""Unit tests for thread safety."""
| |
| import sys |
| import threading |
| import time |
| import unittest |
| |
| from google.protobuf import descriptor_pb2 |
| from google.protobuf import descriptor_pool |
| from google.protobuf import message_factory |
| from google.protobuf.internal import api_implementation |
| |
| from google.protobuf import unittest_pb2 |
| |
class ThreadSafeTest(unittest.TestCase):
  """Checks that parsing the same message type on two threads is consistent."""

  def setUp(self):
    # Running tally of worker parses that produced the expected message.
    self.success = 0

  def testFieldDecodersDataRace(self):
    iterations = 1000
    counter_lock = threading.Lock()
    expected = unittest_pb2.TestAllTypes(optional_int32=1)
    wire_bytes = expected.SerializeToString()
    int32_field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
        'optional_int32'
    ]

    def ParseAndCount():
      result = unittest_pb2.TestAllTypes()
      # NOTE(review): presumably a short pause so both workers reach
      # ParseFromString at roughly the same moment.
      time.sleep(0.005)
      result.ParseFromString(wire_bytes)
      with counter_lock:
        if expected == result:
          self.success += 1

    for _ in range(iterations):
      # Only the first parse of the field may race, so drop any existing
      # `_decoders` attribute to force every iteration down that path.
      if hasattr(int32_field, '_decoders'):
        delattr(int32_field, '_decoders')
      workers = [
          threading.Thread(target=ParseAndCount),
          threading.Thread(target=ParseAndCount),
      ]
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()

    # Two workers per iteration, each of which must have succeeded.
    self.assertEqual(iterations * 2, self.success)
| |
| |
class FreeThreadingTest(unittest.TestCase):
  """Exercises protobuf APIs from many threads concurrently."""

  def RunThreads(self, thread_size, func):
    """Runs `func` on `thread_size` threads concurrently and joins them all.

    An exception raised inside a worker thread is otherwise only reported via
    `threading.excepthook` and never fails the test. Such exceptions are
    therefore collected here, and the first one recorded is re-raised in the
    calling thread once every worker has finished.

    Args:
      thread_size: Number of threads to start.
      func: Zero-argument callable executed once on each thread.

    Raises:
      Exception: The first exception recorded from `func` in any worker.
    """
    errors = []
    errors_lock = threading.Lock()

    def Worker():
      try:
        func()
      except Exception as e:  # pylint: disable=broad-except
        # Recorded for re-raising on the test thread after all joins.
        with errors_lock:
          errors.append(e)

    threads = [threading.Thread(target=Worker) for _ in range(thread_size)]
    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()
    if errors:
      raise errors[0]

  def testDoNothing(self):
    """Smoke test: starting and joining threads with a no-op callable."""
    thread_size = 10

    def DoNothing():
      return

    self.RunThreads(thread_size, DoNothing)

  @unittest.skipIf(
      api_implementation.Type() != 'cpp',
      'Only cpp supports free threading for now',
  )
  def testDescriptorPoolMap(self):
    """Each thread builds its own DescriptorPool and uses a message from it."""
    thread_size = 20
    self.success_count = 0
    lock = threading.Lock()

    def CreatePool():
      def DoCreate():
        # A fresh pool per thread; the same file is added to each one.
        pool = descriptor_pool.DescriptorPool()
        file_proto = descriptor_pb2.FileDescriptorProto(name='foo')
        message_proto = file_proto.message_type.add(name='SomeMessage')
        message_proto.field.add(
            name='int_field',
            number=1,
            type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32,
            label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
        )
        pool.Add(file_proto)
        desc = pool.FindMessageTypeByName('SomeMessage')
        msg = message_factory.GetMessageClass(desc)()
        msg.int_field = 1

      DoCreate()
      # Only reached if DoCreate() did not raise.
      with lock:
        self.success_count += 1

    self.RunThreads(thread_size, CreatePool)
    self.assertEqual(thread_size, self.success_count)
| |
| |
# Run the test cases defined in this module when executed as a script.
if __name__ == '__main__':
  unittest.main()