#include "google/protobuf/reflection_visit_fields.h"

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/log/absl_check.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/map_test_util.h"
#include "google/protobuf/map_unittest.pb.h"
#include "google/protobuf/test_util.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/unittest_mset.pb.h"
#include "google/protobuf/unittest_mset_wire_format.pb.h"
#include "google/protobuf/wire_format.h"
#include "google/protobuf/wire_format_lite.h"

namespace google {
namespace protobuf {
namespace internal {
namespace {

#ifdef __cpp_if_constexpr

using ::protobuf_unittest::NestedTestAllTypes;
using ::protobuf_unittest::TestAllExtensions;
using ::protobuf_unittest::TestAllTypes;
using ::protobuf_unittest::TestMap;
using ::protobuf_unittest::TestOneof2;
using ::protobuf_unittest::TestPackedExtensions;
using ::protobuf_unittest::TestPackedTypes;
using ::proto2_wireformat_unittest::TestMessageSet;

struct TestParam {
  absl::string_view name;
  Message* (*create_message)(Arena& arena);
};

// Parameterized fixture: builds the message described by the current TestParam
// and caches its Reflection interface for the test body.
class VisitFieldsTest : public testing::TestWithParam<TestParam> {
 public:
  // Members are initialized in declaration order, so `arena_` is fully
  // constructed before it is handed to the factory.
  VisitFieldsTest()
      : message_(GetParam().create_message(arena_)),
        reflection_(message_->GetReflection()) {}

 protected:
  Arena arena_;
  Message* message_;
  const Reflection* reflection_;
};

// VisitFields() must invoke the callback once per field reported by
// Reflection::ListFields().
TEST_P(VisitFieldsTest, VisitedFieldsCountMatchesListFields) {
  uint32_t visited = 0;
  const auto count_visit = [&visited](auto /*info*/) { ++visited; };
  VisitFields(*message_, count_visit);

  std::vector<const FieldDescriptor*> listed;
  reflection_->ListFields(*message_, &listed);

  EXPECT_EQ(visited, listed.size());
}

// With FieldMask::kMessage, VisitFields() must invoke the callback once per
// message-typed field reported by Reflection::ListFields().
TEST_P(VisitFieldsTest, VisitedMessageFieldsCountMatchesListFields) {
  // Both counters share one unsigned type so that EXPECT_EQ does not perform a
  // signed/unsigned comparison.
  size_t visited_count = 0;
  VisitFields(
      *message_, [&](auto info) { ++visited_count; }, FieldMask::kMessage);

  std::vector<const FieldDescriptor*> fields;
  reflection_->ListFields(*message_, &fields);
  size_t message_count = 0;
  for (const FieldDescriptor* field : fields) {
    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) ++message_count;
  }

  EXPECT_EQ(visited_count, message_count);
}

// Counts present message fields using ListFields() where:
// --N elements in a repeated message field are counted N times
// --M message values in a map field are counted M times.
// --A map field whose value type is not message is ignored.
int CountAllMessageFieldsViaListFields(const Reflection* reflection,
                                       const Message& message) {
  std::vector<const FieldDescriptor*> fields;
  reflection->ListFields(message, &fields);

  int message_count = 0;
  for (auto field : fields) {
    if (field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) continue;
    if (field->is_map()) {
      if (field->message_type()->map_value()->cpp_type() !=
          FieldDescriptor::CPPTYPE_MESSAGE)
        continue;
    }
    if (field->is_repeated()) {
      message_count += reflection->FieldSize(message, field);
    } else {
      ++message_count;
    }
  }
  return message_count;
}

// VisitMessageFields() must call back once per message value, counting every
// element of repeated (and message-valued map) fields individually.
TEST_P(VisitFieldsTest, VisitMessageFieldsCountIncludesRepeatedElements) {
  int visited = 0;
  VisitMessageFields(*message_, [&visited](const Message&) { ++visited; });

  const int expected =
      CountAllMessageFieldsViaListFields(reflection_, *message_);
  EXPECT_EQ(visited, expected);
}

// Same expectation as the const variant, but through the mutable visitor.
TEST_P(VisitFieldsTest,
       VisitMutableMessageFieldsCountIncludesRepeatedElements) {
  int visited = 0;
  VisitMutableMessageFields(*message_, [&visited](Message&) { ++visited; });

  const int expected =
      CountAllMessageFieldsViaListFields(reflection_, *message_);
  EXPECT_EQ(visited, expected);
}

// Clearing every visited field must leave a message that serializes to zero
// bytes.
TEST_P(VisitFieldsTest, ClearByVisitFieldsMustBeEmpty) {
  VisitFields(*message_, [](auto info) { info.Clear(); });

  // Compare against a size_t zero: ByteSizeLong() returns size_t, and a plain
  // `0` literal would make EXPECT_EQ compare unsigned against int.
  EXPECT_EQ(message_->ByteSizeLong(), size_t{0});
}

// After clearing every visited field, a second visit must not see any field.
TEST_P(VisitFieldsTest, ClearByVisitFieldsRevisitNone) {
  VisitFields(*message_, [](auto info) { info.Clear(); });

  uint32_t count = 0;
  VisitFields(*message_, [&](auto info) { ++count; });

  // Unsigned literal keeps both sides of EXPECT_EQ the same signedness.
  EXPECT_EQ(count, 0u);
}

// Visits every field of `message` through VisitFields() and exercises the
// mutable accessors of each field-info type without changing any value:
// non-message values are written back with the value just read, and
// message-typed values are only obtained mutably and measured with
// ByteSizeLong(). The branch taken for each field is selected at compile time
// from the info object's is_map / is_repeated / cpp_type / is_cord /
// is_string_piece constants.
void MutateNothingByVisit(Message& message) {
  VisitFields(message, [&](auto info) {
    if constexpr (info.is_map) {
      // Map field: touch each (key, value) pair; only the value is rewritten.
      info.VisitElements([&](auto key, auto val) {
        if constexpr (val.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
          // Message value: take the mutable pointer and compute its size; the
          // result is intentionally discarded.
          auto* msg = val.Mutable();
          (void)msg->ByteSizeLong();
        } else {
          val.Set(val.Get());
        }
      });
    } else if constexpr (info.is_repeated) {
      if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
        // Repeated message: iterate the mutable range and measure each element.
        // NOTE(review): `byte_size` is accumulated but never read afterwards.
        size_t byte_size = 0;
        for (auto& it : info.Mutable()) {
          byte_size += it.ByteSizeLong();
        }
      } else if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) {
        // Repeated string: each representation is copied out and assigned back
        // through the API of its element type.
        if constexpr (info.is_cord) {
          for (auto& it : info.Mutable()) {
            absl::Cord tmp = it;
            it = tmp;
          }
        } else if constexpr (info.is_string_piece) {
          for (auto& it : info.Mutable()) {
            it.CopyFrom({it.data(), it.size()});
          }
        } else {
          for (auto& it : info.Mutable()) {
            std::string tmp;
            tmp = it;
            it = tmp;
          }
        }
      } else {
        // Repeated non-string, non-message element: deliberate self-assignment
        // through the mutable reference.
        for (auto& it : info.Mutable()) {
          it = it;
        }
      }
    } else {
      // Singular (non-repeated, non-map) field.
      if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
        // Unlike the map and repeated branches, this one aborts if the visited
        // message has a zero byte size.
        auto& msg = info.Mutable();
        ABSL_CHECK_GT(msg.ByteSizeLong(), 0);
      } else if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_STRING) {
        if constexpr (info.is_cord) {
          info.Set(info.Get());
        } else {
          // Read the current contents as a string_view and pass Set() a value
          // built from that view's pointer and length.
          absl::string_view view = info.Get();
          info.Set({view.data(), view.size()});
        }
      } else {
        info.Set(info.Get());
      }
    }
  });
}

// Rewriting every field with its own value must not change the serialized
// form of the message.
TEST_P(VisitFieldsTest, MutateNothingByVisitIdempotent) {
  std::string serialized_before;
  ASSERT_TRUE(message_->SerializeToString(&serialized_before));

  MutateNothingByVisit(*message_);

  // Checking identity by comparing serialized bytes is discouraged, but it
  // lets this test stay type-agnostic. Also, back-to-back serialization of the
  // same message should be stable.
  const std::string serialized_after = message_->SerializeAsString();
  EXPECT_EQ(serialized_before, serialized_after);
}

// Returns the payload size of a map key: WireFormatLite::StringSize() for
// string keys, MapPrimitiveFieldByteSize() for every other key type. `type`
// is the key's declared field type and is only consulted for non-string keys.
template <typename InfoT>
inline size_t MapKeyByteSizeLong(FieldDescriptor::Type type, InfoT info) {
  constexpr auto kCppType = InfoT::cpp_type;
  if constexpr (kCppType != FieldDescriptor::CPPTYPE_STRING) {
    return MapPrimitiveFieldByteSize<kCppType>(type, info.Get());
  } else {
    return WireFormatLite::StringSize(info.Get());
  }
}

// Returns the payload size of a map value: WireFormatLite::StringSize() for
// strings, WireFormatLite::MessageSize() for messages (a group-typed value is
// a fatal check failure), and MapPrimitiveFieldByteSize() for everything else.
template <typename InfoT>
inline size_t MapValueByteSizeLong(FieldDescriptor::Type type, InfoT info) {
  constexpr auto kCppType = InfoT::cpp_type;
  constexpr bool kIsString = kCppType == FieldDescriptor::CPPTYPE_STRING;
  constexpr bool kIsMessage = kCppType == FieldDescriptor::CPPTYPE_MESSAGE;

  if constexpr (kIsString) {
    return WireFormatLite::StringSize(info.Get());
  } else if constexpr (kIsMessage) {
    ABSL_CHECK_NE(type, FieldDescriptor::TYPE_GROUP);
    return WireFormatLite::MessageSize(info.Get());
  } else {
    return MapPrimitiveFieldByteSize<kCppType>(type, info.Get());
  }
}

// Computes a byte size for `message` by visiting its fields with VisitFields()
// and summing, per field, the tag size(s), any length prefix, and the payload
// size reported by the field info. The result is intended to be compared
// against Message::ByteSizeLong().
size_t ByteSizeLongByVisit(const Message& message) {
  size_t byte_size = 0;

  VisitFields(message, [&](auto info) {
    if constexpr (info.is_map) {
      // A map field is encoded as repeated messages like:
      //
      // message MapField {
      //   message MapEntry {
      //     1: key
      //     2: value
      //   };
      //   repeated MapEntry entry = <FIELD_NUMBER>;
      // }
      // where the key and the value tag are always encoded as two bytes. (no
      // group).
      //
      // Visit each element to calculate the inner size (key, value) and the
      // length field for LENGTH_DELIMITED field.
      size_t size = static_cast<size_t>(info.size());
      ABSL_CHECK_GT(size, 0u);
      // One outer tag per map entry.
      byte_size += size * WireFormat::TagSize(info.number(),
                                              FieldDescriptor::TYPE_MESSAGE);

      FieldDescriptor::Type key_type = info.key_type();
      FieldDescriptor::Type val_type = info.value_type();

      info.VisitElements([&](auto key, auto val) {
        // The leading 2 accounts for the key tag and the value tag.
        size_t inner_size = 2 + MapKeyByteSizeLong(key_type, key) +
                            MapValueByteSizeLong(val_type, val);

        // Length prefix of the entry followed by the entry itself.
        byte_size += io::CodedOutputStream::VarintSize32(
                         static_cast<uint32_t>(inner_size)) +
                     inner_size;
      });
    } else if constexpr (info.is_repeated) {
      if (info.is_packed()) {
        // Packed: a single tag, a varint length, then the whole payload.
        // NOTE(review): TYPE_STRING is passed to TagSize() presumably only to
        // size a length-delimited tag -- confirm.
        size_t payload_size = info.FieldByteSize();
        byte_size +=
            WireFormat::TagSize(info.number(), FieldDescriptor::TYPE_STRING) +
            payload_size +
            io::CodedOutputStream::VarintSize32(
                static_cast<uint32_t>(payload_size));
      } else {
        // Unpacked: one tag per element plus the field's payload size.
        size_t size = static_cast<size_t>(info.size());
        ABSL_CHECK_GT(size, 0u);
        byte_size += size * WireFormat::TagSize(info.number(), info.type()) +
                     info.FieldByteSize();
      }
    } else {
      if constexpr (info.cpp_type == FieldDescriptor::CPPTYPE_MESSAGE) {
        if constexpr (info.is_extension) {
          // Message-typed extension flagged as a message-set item: fixed item
          // tags, the varint-encoded field number, and the length-delimited
          // payload. Returns early so the generic paths below are skipped.
          if (info.type() == FieldDescriptor::TYPE_MESSAGE &&
              info.is_message_set) {
            byte_size +=
                WireFormatLite::kMessageSetItemTagsSize +
                io::CodedOutputStream::VarintSize32(
                    static_cast<uint32_t>(info.number())) +
                WireFormatLite::LengthDelimitedSize(info.FieldByteSize());
            return;
          }
        }
        if (info.type() == FieldDescriptor::TYPE_GROUP) {
          // Group: no length prefix is added here.
          byte_size += WireFormat::TagSize(info.number(), info.type()) +
                       info.FieldByteSize();
        } else if (info.type() == FieldDescriptor::TYPE_MESSAGE) {
          // Embedded message: tag plus length-delimited payload.
          byte_size +=
              WireFormat::TagSize(info.number(), info.type()) +
              WireFormatLite::LengthDelimitedSize(info.FieldByteSize());
        }
      } else {
        // Singular non-message field: tag plus payload size.
        byte_size += WireFormat::TagSize(info.number(), info.type()) +
                     info.FieldByteSize();
      }
    }
  });
  return byte_size;
}

// The size computed field-by-field through VisitFields() must agree with the
// message's own ByteSizeLong().
TEST_P(VisitFieldsTest, ByteSizeByVisitFieldsMatchesCodegen) {
  const size_t size_by_visit = ByteSizeLongByVisit(*message_);
  EXPECT_EQ(size_by_visit, message_->ByteSizeLong());
}

// Builds, on `arena`, a TestMessageSet carrying TestMessageSetExtension1, whose
// `recursive` message set in turn carries TestMessageSetExtension3.
TestMessageSet* CreateTestMessageSet(Arena& arena) {
  TestMessageSet* message_set = Arena::Create<TestMessageSet>(&arena);

  auto* first_ext = message_set->MutableExtension(
      unittest::TestMessageSetExtension1::message_set_extension);
  first_ext->set_i(-1);

  auto* nested_set = first_ext->mutable_recursive();
  auto* third_ext = nested_set->MutableExtension(
      unittest::TestMessageSetExtension3::message_set_extension);
  third_ext->set_required_int(-1);
  third_ext->mutable_msg()->set_b(0);

  return message_set;
}

// Builds, on `arena`, a NestedTestAllTypes whose own payload and whose
// lazy_child's payload are both populated via TestUtil::SetAllFields().
NestedTestAllTypes* CreateNestedTestAllTypes(Arena& arena) {
  NestedTestAllTypes* root = Arena::Create<NestedTestAllTypes>(&arena);

  auto* root_payload = root->mutable_payload();
  TestUtil::SetAllFields(root_payload);

  auto* lazy_child_payload = root->mutable_lazy_child()->mutable_payload();
  TestUtil::SetAllFields(lazy_child_payload);

  return root;
}

// Instantiates VisitFieldsTest once per message flavor below. Each factory
// creates its message on the fixture's arena. The "...Lazy" variants return a
// message obtained by parsing the serialized bytes of a populated original
// (presumably to exercise lazily parsed fields, per their names).
INSTANTIATE_TEST_SUITE_P(
    ReflectionVisitFieldsTest, VisitFieldsTest,
    testing::Values(
        // All regular fields set.
        TestParam{"TestAllTypes",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestAllTypes>(&arena);
                    TestUtil::SetAllFields(msg);
                    return msg;
                  }},
        // All extensions set directly.
        TestParam{"TestAllExtensions",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestAllExtensions>(&arena);
                    TestUtil::SetAllExtensions(msg);
                    return msg;
                  }},
        // All extensions, obtained by parsing a serialized original.
        TestParam{"TestAllExtensionsLazy",
                  [](Arena& arena) -> Message* {
                    TestAllExtensions original;
                    TestUtil::SetAllExtensions(&original);
                    auto* parsed = Arena::Create<TestAllExtensions>(&arena);
                    ABSL_CHECK(
                        parsed->ParseFromString(original.SerializeAsString()));
                    return parsed;
                  }},
        // Map fields.
        TestParam{"TestMap",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestMap>(&arena);
                    MapTestUtil::SetMapFields(msg);
                    return msg;
                  }},
        // Message set with nested message-set extensions.
        TestParam{"TestMessageSet",
                  [](Arena& arena) -> Message* {
                    return CreateTestMessageSet(arena);
                  }},
        // Same message set, obtained by parsing a serialized original.
        TestParam{"TestMessageSetLazy",
                  [](Arena& arena) -> Message* {
                    auto* original = CreateTestMessageSet(arena);
                    auto* parsed = Arena::Create<TestMessageSet>(&arena);
                    ABSL_CHECK(
                        parsed->ParseFromString(original->SerializeAsString()));
                    return parsed;
                  }},
        // Oneof message; foo_lazy_message is mutated after SetOneof2().
        TestParam{"TestOneof2LazyField",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestOneof2>(&arena);
                    TestUtil::SetOneof2(msg);
                    msg->mutable_foo_lazy_message()->set_moo_int(0);
                    return msg;
                  }},
        // Packed repeated fields.
        TestParam{"TestPacked",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestPackedTypes>(&arena);
                    TestUtil::SetPackedFields(msg);
                    return msg;
                  }},
        // Packed repeated extensions.
        TestParam{"TestPackedExtensions",
                  [](Arena& arena) -> Message* {
                    auto* msg = Arena::Create<TestPackedExtensions>(&arena);
                    TestUtil::SetPackedExtensions(msg);
                    return msg;
                  }},
        // Nested message with both payloads populated.
        TestParam{"NestedTestAllTypes",
                  [](Arena& arena) -> Message* {
                    return CreateNestedTestAllTypes(arena);
                  }},
        // Same nested message, obtained by parsing a serialized original.
        TestParam{"NestedTestAllTypesLazy",
                  [](Arena& arena) -> Message* {
                    auto* original = CreateNestedTestAllTypes(arena);
                    auto* parsed = Arena::Create<NestedTestAllTypes>(&arena);
                    ABSL_CHECK(
                        parsed->ParseFromString(original->SerializeAsString()));
                    return parsed;
                  }}),
    // Use the param's `name` as the generated test-name suffix.
    [](const testing::TestParamInfo<TestParam>& info) {
      return std::string(info.param.name);
    });

// Reflectively accesses the `index`-th entry of the map field called `name` in
// `message` through the repeated-message API and invokes
// `callback(entry_reflection, entry_message, value_field_descriptor)` so the
// caller can overwrite the entry's value.
//
// Aborts (via ABSL_CHECK) if `name` is not a field of TestMap or if the entry
// message has no map value field.
template <typename T>
void MutateMapValue(TestMap& message, absl::string_view name, int index,
                    T&& callback) {
  const Reflection* reflection = message.GetReflection();
  const Descriptor* descriptor = message.GetDescriptor();
  const FieldDescriptor* field = descriptor->FindFieldByName(name);
  // FindFieldByName() returns nullptr for an unknown name; fail loudly here
  // instead of handing a null descriptor to the reflection API below.
  ABSL_CHECK_NE(field, nullptr) << "no such field: " << name;

  auto* map_entry = reflection->MutableRepeatedMessage(&message, field, index);
  const FieldDescriptor* val_field = map_entry->GetDescriptor()->map_value();
  ABSL_CHECK_NE(val_field, nullptr);
  std::forward<T>(callback)(map_entry->GetReflection(), map_entry, val_field);
}

// A map that has been modified through the reflective repeated-message API
// must still be visited as a map, and the visit must observe the new values.
TEST(ReflectionVisitTest, VisitMapAfterMutableRepeated) {
  TestMap message;
  auto* int_map = message.mutable_map_int32_int32();
  (*int_map)[0] = 0;
  (*int_map)[1] = 0;

  // Reflectively overwrite the value of every entry with 200. Going through
  // the repeated-message API forces conversion to a mutable repeated field.
  auto overwrite_with_200 = [](const Reflection* entry_reflection,
                               Message* entry,
                               const FieldDescriptor* value_field) {
    entry_reflection->SetInt32(entry, value_field, 200);
  };
  for (int entry_index = 0; entry_index < 2; ++entry_index) {
    MutateMapValue(message, "map_int32_int32", entry_index,
                   overwrite_with_200);
  }

  // A later visit must see map fields that are in sync with the change.
  std::vector<std::pair<int32_t, int32_t>> observed;
  VisitFields(message, [&observed](auto field_info) {
    if constexpr (field_info.is_map) {
      ASSERT_EQ(field_info.key_type(), FieldDescriptor::TYPE_INT32);
      ASSERT_EQ(field_info.value_type(), FieldDescriptor::TYPE_INT32);

      field_info.VisitElements([&observed](auto key, auto val) {
        if constexpr (key.cpp_type == FieldDescriptor::CPPTYPE_INT32 &&
                      val.cpp_type == FieldDescriptor::CPPTYPE_INT32) {
          observed.emplace_back(key.Get(), val.Get());
        }
      });
    }
  });

  EXPECT_THAT(observed, testing::UnorderedElementsAre(testing::Pair(0, 200),
                                                      testing::Pair(1, 200)));
}

#endif  // __cpp_if_constexpr

}  // namespace
}  // namespace internal
}  // namespace protobuf
}  // namespace google
