# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""This module defines the generated code for pw_protobuf C++ classes."""

import abc
from datetime import datetime
import enum
import os
import sys
from typing import Dict, Iterable, List, Tuple
from typing import cast

import google.protobuf.descriptor_pb2 as descriptor_pb2

from pw_protobuf.output_file import OutputFile
from pw_protobuf.proto_tree import ProtoEnum, ProtoMessage, ProtoMessageField
from pw_protobuf.proto_tree import ProtoNode
from pw_protobuf.proto_tree import build_node_tree

# Name and version of this plugin. Both are written into the banner comment at
# the top of each generated header.
PLUGIN_NAME = 'pw_protobuf'
PLUGIN_VERSION = '0.1.0'

# Extension that replaces the .proto file's extension to form the name of the
# generated header.
PROTO_H_EXTENSION = '.pwpb.h'
# NOTE(review): not referenced in the code visible here; presumably the
# extension for a generated C++ source file -- confirm before relying on it.
PROTO_CC_EXTENSION = '.pwpb.cc'

# C++ namespace used to qualify the encoder base class that generated classes
# inherit from.
PROTOBUF_NAMESPACE = 'pw::protobuf'


class EncoderType(enum.Enum):
    """The kinds of encoder for which C++ wrapper classes are generated."""
    MEMORY = 1
    STREAMING = 2
    LEGACY = 3

    def base_class_name(self) -> str:
        """Returns the name of the base class for this encoder type."""
        base_classes = {
            EncoderType.LEGACY: 'ProtoMessageEncoder',
            EncoderType.STREAMING: 'StreamingEncoder',
            EncoderType.MEMORY: 'MemoryEncoder',
        }
        class_name = base_classes.get(self)
        if class_name is None:
            raise ValueError('Unknown encoder type')
        return class_name

    def codegen_class_name(self) -> str:
        """Returns the name of the generated class for this encoder type."""
        generated_classes = {
            EncoderType.LEGACY: 'Encoder',
            EncoderType.STREAMING: 'StreamEncoder',
            # TODO(pwbug/384): Make this the 'Encoder'
            EncoderType.MEMORY: 'RamEncoder',
        }
        class_name = generated_classes.get(self)
        if class_name is None:
            raise ValueError('Unknown encoder type')
        return class_name


# protoc captures the plugin's stdout, so debug output has to go to stderr.
def debug_print(*args, **kwargs):
    """Prints to stderr, forwarding all arguments to print()."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


class ProtoMethod(abc.ABC):
    """Abstract description of a generated C++ method for a message field.

    Subclasses provide the method's name, parameters, return type and body;
    the code generator queries them to write out the C++ source.
    """
    def __init__(
        self,
        field: ProtoMessageField,
        scope: ProtoNode,
        root: ProtoNode,
    ):
        """Creates an instance of a method.

        Args:
          field: the ProtoMessageField for which the method is generated.
          scope: the ProtoNode namespace in which the method is being defined.
          root: the ProtoNode used as the scope when a type has to be named
              from the root instead of relative to `scope`.
        """
        self._root: ProtoNode = root
        self._scope: ProtoNode = scope
        self._field: ProtoMessageField = field

    @abc.abstractmethod
    def name(self) -> str:
        """Returns the method's name, such as DoSomething."""

    @abc.abstractmethod
    def params(self) -> List[Tuple[str, str]]:
        """Returns the method's parameters as (type, name) pairs.

        For example:
        [('int', 'foo'), ('const char*', 'bar')]
        """

    @abc.abstractmethod
    def body(self, encoder_type: EncoderType) -> List[str]:
        """Returns the lines of source code making up the method body.

        For example:
        [
          'int baz = bar[foo];',
          'return (baz ^ foo) >> 3;'
        ]
        """

    @abc.abstractmethod
    def return_type(self,
                    encoder_type: EncoderType,
                    from_root: bool = False) -> str:
        """Returns the method's return type, such as int.

        When the return type is not a primitive, from_root selects whether
        its namespace is written relative to the root scope (True) or
        relative to the message's own scope (False, the default).
        """

    @abc.abstractmethod
    def in_class_definition(self) -> bool:
        """Tells the generator where to put the method's definition.

        True means the method is defined inline inside its class definition.
        False means it is only declared in the class and defined later.
        """

    def should_appear(self) -> bool:  # pylint: disable=no-self-use
        """Returns True if the method should be generated at all."""
        return True

    def param_string(self) -> str:
        """Returns the parameters formatted as a C++ parameter list."""
        declarations = (f'{param_type} {param_name}'
                        for param_type, param_name in self.params())
        return ', '.join(declarations)

    def field_cast(self) -> str:
        """Returns a C++ expression casting the field's enum to uint32_t."""
        return f'static_cast<uint32_t>(Fields::{self._field.enum_name()})'

    def _relative_type_namespace(self, from_root: bool = False) -> str:
        """Returns the field type's namespace relative to the method's scope."""
        field_type = self._field.type_node()
        assert field_type is not None

        # A class method that refers to its own class has to name it from
        # the root; relative to itself the namespace would be empty.
        if from_root or field_type == self._scope:
            origin = self._root
        else:
            origin = self._scope

        shared_ancestor = origin.common_ancestor(field_type)
        relative_namespace = field_type.cpp_namespace(shared_ancestor)
        assert relative_namespace
        return relative_namespace


class SubMessageMethod(ProtoMethod):
    """Method which hands out an encoder for a nested message field."""
    def name(self) -> str:
        """Returns Get<FieldName>Encoder."""
        return f'Get{self._field.name()}Encoder'

    def params(self) -> List[Tuple[str, str]]:
        """The getter takes no parameters."""
        return []

    def return_type(self,
                    encoder_type: EncoderType,
                    from_root: bool = False) -> str:
        """Returns the nested message's encoder class for encoder_type."""
        if encoder_type == EncoderType.LEGACY:
            encoder_class = 'Encoder'
        else:
            encoder_class = 'StreamEncoder'

        namespace = self._relative_type_namespace(from_root)
        return f'{namespace}::{encoder_class}'

    def body(self, encoder_type: EncoderType) -> List[str]:
        """Returns the single statement that constructs the nested encoder."""
        is_legacy = encoder_type == EncoderType.LEGACY
        namespace = self._relative_type_namespace()
        field_number = self.field_cast()

        if is_legacy:
            return [f'return {namespace}::Encoder(encoder_, {field_number});']

        return [
            f'return {namespace}::StreamEncoder('
            f'GetNestedEncoder({field_number}));'
        ]

    def in_class_definition(self) -> bool:
        """Returns False: the method is defined outside of the class.

        The submessage class may not have been defined yet at the point
        where this method is declared.
        """
        return False


class WriteMethod(ProtoMethod):
    """Shared implementation of the encoder write methods.

    For the proto field foo, a legacy write method has the format:

        ::pw::Status WriteFoo({params...}) {
          return encoder_->Write{type}({field number}, {params...});
        }

    For the other encoder types the write function is called directly
    rather than through encoder_.
    """
    def name(self) -> str:
        """Returns Write<FieldName>."""
        return f'Write{self._field.name()}'

    def params(self) -> List[Tuple[str, str]]:
        """Method parameters, defined in subclasses."""
        raise NotImplementedError()

    def return_type(self,
                    encoder_type: EncoderType,
                    from_root: bool = False) -> str:
        """Every write method returns a status."""
        return '::pw::Status'

    def in_class_definition(self) -> bool:
        """Write methods are defined inline in the class."""
        return True

    def body(self, encoder_type: EncoderType) -> List[str]:
        """Returns the statement forwarding the arguments to the encoder."""
        arguments = ', '.join(param[1] for param in self.params())
        # Only the legacy encoder is reached through the encoder_ member.
        receiver = 'encoder_->' if encoder_type == EncoderType.LEGACY else ''
        call = f'{self._encoder_fn()}({self.field_cast()}, {arguments})'
        return [f'return {receiver}{call};']

    def _encoder_fn(self) -> str:
        """Name of the encoder function to call, defined in subclasses.

        For example 'WriteUint32' or 'WriteBytes'.
        """
        raise NotImplementedError()


class PackedMethod(WriteMethod):
    """Write method for a packed repeated field.

    Works like a WriteMethod, except that it is generated only when the
    field is repeated.
    """
    def _encoder_fn(self) -> str:
        """Name of the packed write function, defined in subclasses."""
        raise NotImplementedError()

    def should_appear(self) -> bool:
        """Restricts generation to repeated fields."""
        return self._field.is_repeated()


#
# The following code defines write methods for each of the
# primitive protobuf types.
#


class DoubleMethod(WriteMethod):
    """Write method for a proto double field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one double."""
        return 'WriteDouble'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the double to write."""
        return [('double', 'value')]


class PackedDoubleMethod(PackedMethod):
    """Write method for a packed list of doubles."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed doubles."""
        return 'WritePackedDouble'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the doubles to write."""
        return [('std::span<const double>', 'values')]


class FloatMethod(WriteMethod):
    """Write method for a proto float field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one float."""
        return 'WriteFloat'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the float to write."""
        return [('float', 'value')]


class PackedFloatMethod(PackedMethod):
    """Write method for a packed list of floats."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed floats."""
        return 'WritePackedFloat'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the floats to write."""
        return [('std::span<const float>', 'values')]


class Int32Method(WriteMethod):
    """Write method for a proto int32 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one int32."""
        return 'WriteInt32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int32_t to write."""
        return [('int32_t', 'value')]


class PackedInt32Method(PackedMethod):
    """Write method for a packed list of int32 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed int32 values."""
        return 'WritePackedInt32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int32_t values to write."""
        return [('std::span<const int32_t>', 'values')]


class Sint32Method(WriteMethod):
    """Write method for a proto sint32 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one sint32."""
        return 'WriteSint32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int32_t to write."""
        return [('int32_t', 'value')]


class PackedSint32Method(PackedMethod):
    """Write method for a packed list of sint32 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed sint32 values."""
        return 'WritePackedSint32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int32_t values to write."""
        return [('std::span<const int32_t>', 'values')]


class Sfixed32Method(WriteMethod):
    """Write method for a proto sfixed32 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one sfixed32."""
        return 'WriteSfixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int32_t to write."""
        return [('int32_t', 'value')]


class PackedSfixed32Method(PackedMethod):
    """Write method for a packed list of sfixed32 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed sfixed32 values."""
        return 'WritePackedSfixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int32_t values to write."""
        return [('std::span<const int32_t>', 'values')]


class Int64Method(WriteMethod):
    """Write method for a proto int64 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one int64."""
        return 'WriteInt64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int64_t to write."""
        return [('int64_t', 'value')]


class PackedInt64Method(PackedMethod):
    """Write method for a packed list of int64 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed int64 values."""
        return 'WritePackedInt64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int64_t values to write."""
        return [('std::span<const int64_t>', 'values')]


class Sint64Method(WriteMethod):
    """Write method for a proto sint64 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one sint64."""
        return 'WriteSint64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int64_t to write."""
        return [('int64_t', 'value')]


class PackedSint64Method(PackedMethod):
    """Write method for a packed list of sint64 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed sint64 values."""
        return 'WritePackedSint64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int64_t values to write."""
        return [('std::span<const int64_t>', 'values')]


class Sfixed64Method(WriteMethod):
    """Write method for a proto sfixed64 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one sfixed64."""
        return 'WriteSfixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the int64_t to write."""
        return [('int64_t', 'value')]


class PackedSfixed64Method(PackedMethod):
    """Method which writes a packed list of sfixed64 values."""
    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the int64_t values to write."""
        return [('std::span<const int64_t>', 'values')]

    def _encoder_fn(self) -> str:
        """Encoder function that writes packed sfixed64 values."""
        # The name follows the same WritePacked<Type> pattern as the other
        # packed methods; it was previously misspelled 'WritePackedSfixed4'.
        return 'WritePackedSfixed64'


class Uint32Method(WriteMethod):
    """Write method for a proto uint32 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one uint32."""
        return 'WriteUint32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the uint32_t to write."""
        return [('uint32_t', 'value')]


class PackedUint32Method(PackedMethod):
    """Write method for a packed list of uint32 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed uint32 values."""
        return 'WritePackedUint32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the uint32_t values to write."""
        return [('std::span<const uint32_t>', 'values')]


class Fixed32Method(WriteMethod):
    """Write method for a proto fixed32 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one fixed32."""
        return 'WriteFixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the uint32_t to write."""
        return [('uint32_t', 'value')]


class PackedFixed32Method(PackedMethod):
    """Write method for a packed list of fixed32 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed fixed32 values."""
        return 'WritePackedFixed32'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the uint32_t values to write."""
        return [('std::span<const uint32_t>', 'values')]


class Uint64Method(WriteMethod):
    """Write method for a proto uint64 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one uint64."""
        return 'WriteUint64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the uint64_t to write."""
        return [('uint64_t', 'value')]


class PackedUint64Method(PackedMethod):
    """Write method for a packed list of uint64 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed uint64 values."""
        return 'WritePackedUint64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the uint64_t values to write."""
        return [('std::span<const uint64_t>', 'values')]


class Fixed64Method(WriteMethod):
    """Write method for a proto fixed64 field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one fixed64."""
        return 'WriteFixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the uint64_t to write."""
        return [('uint64_t', 'value')]


class PackedFixed64Method(PackedMethod):
    """Write method for a packed list of fixed64 values."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes packed fixed64 values."""
        return 'WritePackedFixed64'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the uint64_t values to write."""
        return [('std::span<const uint64_t>', 'values')]


class BoolMethod(WriteMethod):
    """Write method for a proto bool field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes one bool."""
        return 'WriteBool'

    def params(self) -> List[Tuple[str, str]]:
        """Takes the bool to write."""
        return [('bool', 'value')]


class BytesMethod(WriteMethod):
    """Write method for a proto bytes field."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes a bytes value."""
        return 'WriteBytes'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a span over the bytes to write."""
        return [('std::span<const std::byte>', 'value')]


class StringLenMethod(WriteMethod):
    """Write method for a proto string field, given with its length."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes a string."""
        return 'WriteString'

    def params(self) -> List[Tuple[str, str]]:
        """Takes a character pointer and a length."""
        return [('const char*', 'value'), ('size_t', 'len')]


class StringMethod(WriteMethod):
    """Write method for a proto string field, given as a pointer only."""
    def _encoder_fn(self) -> str:
        """Encoder function that writes a string."""
        return 'WriteString'

    def params(self) -> List[Tuple[str, str]]:
        """Takes only a character pointer."""
        return [('const char*', 'value')]


class EnumMethod(WriteMethod):
    """Write method for a proto enum field.

    The generated body casts the enum value to uint32_t and writes it with
    WriteUint32, so _encoder_fn is never consulted.
    """
    def params(self) -> List[Tuple[str, str]]:
        """Takes a value of the field's enum type."""
        enum_type = self._relative_type_namespace()
        return [(enum_type, 'value')]

    def body(self, encoder_type: EncoderType) -> List[str]:
        """Returns the statement that writes the enum as a uint32."""
        if encoder_type == EncoderType.LEGACY:
            write_fn = 'encoder_->WriteUint32'
        else:
            write_fn = 'WriteUint32'

        return [
            f'return {write_fn}({self.field_cast()}, '
            'static_cast<uint32_t>(value));'
        ]

    def in_class_definition(self) -> bool:
        """Enum write methods are defined inline in the class."""
        return True

    def _encoder_fn(self) -> str:
        """Unused, because body() is overridden."""
        raise NotImplementedError()


# Maps each protobuf field type to the method classes generated for a field of
# that type. Every class in a list is instantiated for the field; a class may
# still opt out of generation through should_appear().
PROTO_FIELD_METHODS: Dict[int, List] = {
    descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE: [
        DoubleMethod,
        PackedDoubleMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT: [
        FloatMethod,
        PackedFloatMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT32: [
        Int32Method,
        PackedInt32Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: [
        Sint32Method,
        PackedSint32Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED32: [
        Sfixed32Method,
        PackedSfixed32Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_INT64: [
        Int64Method,
        PackedInt64Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: [
        Sint64Method,
        PackedSint64Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_SFIXED64: [
        Sfixed64Method,
        PackedSfixed64Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: [
        Uint32Method,
        PackedUint32Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED32: [
        Fixed32Method,
        PackedFixed32Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: [
        Uint64Method,
        PackedUint64Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_FIXED64: [
        Fixed64Method,
        PackedFixed64Method,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: [
        BoolMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: [
        BytesMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_STRING: [
        StringLenMethod,
        StringMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [
        SubMessageMethod,
    ],
    descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [
        EnumMethod,
    ],
}


def generate_code_for_message(message: ProtoMessage, root: ProtoNode,
                              output: OutputFile,
                              encoder_type: EncoderType) -> None:
    """Writes the C++ encoder class definition for a protobuf message.

    The class derives from the base encoder selected by encoder_type and
    receives the methods listed in PROTO_FIELD_METHODS for each field.
    Methods that are not defined in the class are only declared here.

    Args:
      message: the message node to generate a class for.
      root: the node relative to which C++ names are computed.
      output: the file that receives the generated lines.
      encoder_type: the kind of encoder class to generate.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    base_name = encoder_type.base_class_name()
    class_name = encoder_type.codegen_class_name()

    # The generated class inherits from the base encoder class, which is
    # named with its full namespace, and reuses its constructors.
    qualified_base = f'{PROTOBUF_NAMESPACE}::{base_name}'
    class_header = (f'class {message.cpp_namespace(root)}::{class_name} '
                    f': public {qualified_base} {{')
    output.write_line(class_header)
    output.write_line(' public:')

    with output.indent():
        # Pull in the base encoder's constructors.
        output.write_line(f'using {base_name}::{base_name};')

        # Non-legacy encoders can also be move-constructed from a base
        # encoder.
        if encoder_type != EncoderType.LEGACY:
            move_constructor = (f'constexpr {class_name}({base_name}&& parent) '
                                f': {base_name}(std::move(parent)) {{}}')
            output.write_line(move_constructor)

        # Emit the methods belonging to each field of the message.
        for field in message.fields():
            for method_class in PROTO_FIELD_METHODS[field.type()]:
                method = method_class(field, message, root)
                if not method.should_appear():
                    continue

                output.write_line()
                return_type = method.return_type(encoder_type)
                signature = (
                    f'{return_type} {method.name()}({method.param_string()})')

                if method.in_class_definition():
                    output.write_line(f'{signature} {{')
                    with output.indent():
                        for body_line in method.body(encoder_type):
                            output.write_line(body_line)
                    output.write_line('}')
                else:
                    # Declaration only; the definition is written outside of
                    # the class, at the end of the file.
                    output.write_line(f'{signature};')

    output.write_line('};')


def define_not_in_class_methods(message: ProtoMessage, root: ProtoNode,
                                output: OutputFile,
                                encoder_type: EncoderType) -> None:
    """Writes out-of-class definitions for a message's declared-only methods.

    Only methods that should appear and that are not defined inside their
    class are emitted, each as an inline function qualified with the
    generated class name.

    Args:
      message: the message node whose methods are defined.
      root: the node relative to which C++ names are computed.
      output: the file that receives the generated lines.
      encoder_type: the kind of encoder class the methods belong to.
    """
    assert message.type() == ProtoNode.Type.MESSAGE

    for field in message.fields():
        for method_class in PROTO_FIELD_METHODS[field.type()]:
            method = method_class(field, message, root)
            if method.should_appear() and not method.in_class_definition():
                output.write_line()

                namespace = message.cpp_namespace(root)
                owner = f'{namespace}::{encoder_type.codegen_class_name()}'

                # The definition is outside of the class, so the return type
                # is named from the root scope.
                return_type = method.return_type(encoder_type, from_root=True)
                signature = (f'inline {return_type} '
                             f'{owner}::{method.name()}'
                             f'({method.param_string()})')

                output.write_line(f'{signature} {{')
                with output.indent():
                    for body_line in method.body(encoder_type):
                        output.write_line(body_line)
                output.write_line('}')


def generate_code_for_enum(proto_enum: ProtoEnum, root: ProtoNode,
                           output: OutputFile) -> None:
    """Writes a C++ enum class mirroring a proto enum.

    Args:
      proto_enum: the enum node to generate code for.
      root: the node relative to which the enum's C++ name is computed.
      output: the file that receives the generated lines.
    """
    assert proto_enum.type() == ProtoNode.Type.ENUM

    enum_name = proto_enum.cpp_namespace(root)
    output.write_line(f'enum class {enum_name} {{')
    with output.indent():
        for value_name, value_number in proto_enum.values():
            enumerator = f'{value_name} = {value_number},'
            output.write_line(enumerator)
    output.write_line('};')


def forward_declare(node: ProtoMessage, root: ProtoNode,
                    output: OutputFile) -> None:
    """Writes the namespace holding a message's forward declarations.

    The namespace contains the Fields enum listing the message's field
    numbers, forward declarations of the encoder classes, and the definitions
    of the enums that are direct children of the message.

    Args:
      node: the message node to forward-declare.
      root: the node relative to which the namespace name is computed.
      output: the file that receives the generated lines.
    """
    namespace = node.cpp_namespace(root)
    output.write_line()
    output.write_line(f'namespace {namespace} {{')

    # One enumerator per field, holding the field's number.
    output.write_line('enum class Fields {')
    with output.indent():
        for message_field in node.fields():
            enumerator = (
                f'{message_field.enum_name()} = {message_field.number()},')
            output.write_line(enumerator)
    output.write_line('};')

    # Forward-declare every encoder class generated for the message.
    output.write_line()
    for encoder_class in ('Encoder', 'StreamEncoder', 'RamEncoder'):
        output.write_line(f'class {encoder_class};')

    # Enums nested in the message are defined in full right here.
    for child in node.children():
        if child.type() != ProtoNode.Type.ENUM:
            continue
        output.write_line()
        generate_code_for_enum(cast(ProtoEnum, child), node, output)

    output.write_line(f'}}  // namespace {namespace}')


def generate_encoder_wrappers(package: ProtoNode, encoder_type: EncoderType,
                              output: OutputFile):
    """Writes the encoder classes of one encoder type for a whole package.

    Args:
      package: the node whose messages are iterated.
      encoder_type: the kind of encoder class to generate.
      output: the file that receives the generated lines.
    """
    # First pass: one class definition per message.
    for node in package:
        if node.type() != ProtoNode.Type.MESSAGE:
            continue
        output.write_line()
        generate_code_for_message(cast(ProtoMessage, node), package, output,
                                  encoder_type)

    # Second pass: define the methods that the first pass only declared.
    for node in package:
        if node.type() != ProtoNode.Type.MESSAGE:
            continue
        define_not_in_class_methods(cast(ProtoMessage, node), package, output,
                                    encoder_type)


def _proto_filename_to_generated_header(proto_file: str) -> str:
    """Maps a .proto file name to the name of its generated C++ header.

    The file's extension is replaced with PROTO_H_EXTENSION.
    """
    stem, _ = os.path.splitext(proto_file)
    return stem + PROTO_H_EXTENSION


def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
                              output: OutputFile) -> None:
    """Generates the .pwpb.h header corresponding to a single .proto file.

    The header consists of a banner comment, the includes, forward
    declarations for every message, the top-level enums, and then the encoder
    classes for each encoder type, all wrapped in the package's namespace
    when it has one.

    Args:
      file_descriptor_proto: descriptor of the .proto file; its `dependency`
          list determines which generated headers are included.
      package: the package node whose contents are generated.
      output: the file that receives the generated lines.
    """

    assert package.type() == ProtoNode.Type.PACKAGE

    output.write_line(f'// {os.path.basename(output.name())} automatically '
                      f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
    # The timestamp makes the generated banner differ from run to run.
    output.write_line(f'// on {datetime.now()}')
    output.write_line('#pragma once\n')
    output.write_line('#include <cstddef>')
    output.write_line('#include <cstdint>')
    output.write_line('#include <span>\n')
    output.write_line('#include "pw_protobuf/codegen.h"')
    output.write_line('#include "pw_protobuf/streaming_encoder.h"')

    # Include the generated header of every imported .proto file.
    for imported_file in file_descriptor_proto.dependency:
        generated_header = _proto_filename_to_generated_header(imported_file)
        output.write_line(f'#include "{generated_header}"')

    # Compute the namespace name once, without any leading '::', so that the
    # opening line and the closing comment always agree with each other.
    file_namespace = package.cpp_namespace()
    if file_namespace and file_namespace.startswith('::'):
        file_namespace = file_namespace[2:]

    if file_namespace:
        output.write_line(f'\nnamespace {file_namespace} {{')

    for node in package:
        if node.type() == ProtoNode.Type.MESSAGE:
            forward_declare(cast(ProtoMessage, node), package, output)

    # Define all top-level enums.
    for node in package.children():
        if node.type() == ProtoNode.Type.ENUM:
            output.write_line()
            generate_code_for_enum(cast(ProtoEnum, node), package, output)

    generate_encoder_wrappers(package, EncoderType.LEGACY, output)
    generate_encoder_wrappers(package, EncoderType.STREAMING, output)
    generate_encoder_wrappers(package, EncoderType.MEMORY, output)

    if file_namespace:
        output.write_line(f'\n}}  // namespace {file_namespace}')


def process_proto_file(proto_file) -> Iterable[OutputFile]:
    """Generates the output files for a single .proto file.

    Args:
      proto_file: descriptor of the .proto file; its `name` determines the
          name of the generated header.

    Returns:
      A list holding the one generated header file.
    """

    # The node tree is built completely before any code is generated, since
    # non-primitive fields need to refer to the nodes of their types, which
    # requires the entire tree to have been parsed into memory.
    _global_root, package_root = build_node_tree(proto_file)

    header_name = _proto_filename_to_generated_header(proto_file.name)
    header = OutputFile(header_name)
    generate_code_for_package(proto_file, package_root, header)

    return [header]
