#!/usr/bin/env python3
"""
MCTP Packet Parser
Parses MCTP packets based on DSP0236 v1.3.1 and SPDM over MCTP (DSP0275 v1.0.2)
Parses SPDM messages based on DSP0274 v1.3.0
"""

import sys
import argparse
from typing import List, Optional, Dict


class SPDMParser:
    """Decoder for the fixed four-byte header of an SPDM message (DSP0274).

    Only the version byte, the request/response code and the two parameter
    bytes are interpreted; everything after them is returned untouched.
    """

    # Request code -> message name (DSP0274)
    REQUEST_CODES = {
        0x81: "GET_DIGESTS",
        0x82: "GET_CERTIFICATE",
        0x83: "CHALLENGE",
        0x84: "GET_VERSION",
        0x85: "CHUNK_SEND",
        0x86: "CHUNK_GET",
        0x87: "GET_ENDPOINT_INFO",
        0xE0: "GET_MEASUREMENTS",
        0xE1: "GET_CAPABILITIES",
        0xE2: "GET_SUPPORTED_EVENT_TYPES",
        0xE3: "NEGOTIATE_ALGORITHMS",
        0xE4: "KEY_EXCHANGE",
        0xE5: "FINISH",
        0xE6: "PSK_EXCHANGE",
        0xE7: "PSK_FINISH",
        0xE8: "HEARTBEAT",
        0xE9: "KEY_UPDATE",
        0xEA: "GET_ENCAPSULATED_REQUEST",
        0xEB: "DELIVER_ENCAPSULATED_RESPONSE",
        0xEC: "END_SESSION",
        0xED: "GET_CSR",
        0xEE: "SET_CERTIFICATE",
        0xEF: "GET_MEASUREMENT_EXTENSION_LOG",
        0xF0: "SUBSCRIBE_EVENT_TYPES",
        0xF1: "SEND_EVENT",
        0xFC: "GET_KEY_PAIR_INFO",
        0xFD: "SET_KEY_PAIR_INFO",
        0xFE: "VENDOR_DEFINED_REQUEST",
        0xFF: "RESPOND_IF_READY",
    }

    # Response code -> message name (DSP0274)
    RESPONSE_CODES = {
        0x01: "DIGESTS",
        0x02: "CERTIFICATE",
        0x03: "CHALLENGE_AUTH",
        0x04: "VERSION",
        0x05: "CHUNK_SEND_ACK",
        0x06: "CHUNK_RESPONSE",
        0x07: "ENDPOINT_INFO",
        0x60: "MEASUREMENTS",
        0x61: "CAPABILITIES",
        0x62: "SUPPORTED_EVENT_TYPES",
        0x63: "ALGORITHMS",
        0x64: "KEY_EXCHANGE_RSP",
        0x65: "FINISH_RSP",
        0x66: "PSK_EXCHANGE_RSP",
        0x67: "PSK_FINISH_RSP",
        0x68: "HEARTBEAT_ACK",
        0x69: "KEY_UPDATE_ACK",
        0x6A: "ENCAPSULATED_REQUEST",
        0x6B: "ENCAPSULATED_RESPONSE_ACK",
        0x6C: "END_SESSION_ACK",
        0x6D: "CSR",
        0x6E: "SET_CERTIFICATE_RSP",
        0x6F: "MEASUREMENT_EXTENSION_LOG",
        0x70: "SUBSCRIBE_EVENT_TYPES_ACK",
        0x71: "EVENT_ACK",
        0x7C: "KEY_PAIR_INFO",
        0x7D: "SET_KEY_PAIR_INFO_ACK",
        0x7E: "VENDOR_DEFINED_RESPONSE",
        0x7F: "ERROR",
    }

    @staticmethod
    def parse(payload: List[int]) -> Optional[Dict]:
        """Decode the leading SPDM header fields of *payload*.

        Args:
            payload: Message bytes, starting with the SPDM version byte.

        Returns:
            None when fewer than two bytes are supplied. Otherwise a dict with
            'version_major', 'version_minor', 'version', 'code',
            'msg_direction', 'msg_name', 'data' and 'data_len'; 'param1' and
            'param2' are included only when bytes 2 and 3 exist.
        """
        size = len(payload)
        if size < 2:
            return None

        version_byte, code = payload[0], payload[1]
        major = (version_byte >> 4) & 0x0F   # bits [7:4]
        minor = version_byte & 0x0F          # bits [3:0]

        # Request codes are consulted first, then response codes.
        lookup_order = (
            ("Request", SPDMParser.REQUEST_CODES),
            ("Response", SPDMParser.RESPONSE_CODES),
        )
        for direction, table in lookup_order:
            if code in table:
                name = table[code]
                break
        else:
            direction = "Unknown"
            name = f"Unknown (0x{code:02X})"

        result = {
            'version_major': major,
            'version_minor': minor,
            'version': f"{major}.{minor}",
            'code': code,
            'msg_direction': direction,
            'msg_name': name,
        }

        # Param1 / Param2 are optional from this parser's point of view.
        if size > 2:
            result['param1'] = payload[2]
        if size > 3:
            result['param2'] = payload[3]

        # Whatever follows the four header bytes is kept as opaque data.
        tail = payload[4:] if size > 4 else []
        result['data'] = tail
        result['data_len'] = len(tail)

        return result


class MCTPParser:
    """Parser for a single MCTP packet.

    The first four bytes are read in this order: destination EID, header byte
    (version in bits [7:4], reserved in bits [3:0]), source EID, flags byte
    (SOM/EOM/Pkt_Seq/TO/Msg_Tag). The message body follows.

    NOTE(review): DSP0236 places the reserved/version byte first, with the
    version in the low nibble. The order used here matches the synthetic
    packet that main() assembles for the SMBus path; confirm before feeding
    raw on-the-wire MCTP headers to this class.
    """

    # Message type codes (assigned in DSP0239, "MCTP IDs and Codes")
    MESSAGE_TYPES = {
        0x00: "MCTP Control Message",
        0x01: "PLDM",
        0x02: "NC-SI over MCTP",
        0x03: "Ethernet over MCTP",
        0x04: "NVMe-MI over MCTP",
        0x05: "SPDM over MCTP",
        0x06: "Secured Messages over MCTP",
        0x07: "CXL FM API over MCTP",
        0x08: "CXL CCI over MCTP",
    }

    def __init__(self, data: List[int]):
        """Store the packet bytes to be parsed.

        Args:
            data: Packet bytes, starting with the destination EID.
        """
        self.data = data
        self.offset = 0  # not used by this class; kept for compatibility

    def parse(self) -> dict:
        """Parse the packet and return its fields as a dict.

        Returns:
            A dict that always holds 'dest_eid', 'header_version', 'rsvd',
            'src_eid', 'som', 'eom', 'pkt_seq', 'to' and 'msg_tag'.
            When body bytes follow the header, 'payload' and 'payload_len'
            are added. 'ic', 'msg_type' and 'msg_type_name' are added only
            for packets with SOM set, because the IC/message-type byte is
            carried only in the first packet of a message. 'spdm' is added
            for SOM packets of type 0x05 that have a non-empty payload.

        Raises:
            ValueError: If fewer than four bytes are available.
        """
        if len(self.data) < 4:
            raise ValueError("Packet too short - minimum 4 bytes required for MCTP header")

        dest_eid, version_byte, src_eid, flags = self.data[:4]
        result = {
            'dest_eid': dest_eid,
            'header_version': (version_byte >> 4) & 0x0F,
            'rsvd': version_byte & 0x0F,
            'src_eid': src_eid,
            'som': bool(flags & 0x80),        # Start of Message
            'eom': bool(flags & 0x40),        # End of Message
            'pkt_seq': (flags >> 4) & 0x03,   # Packet sequence number
            'to': bool(flags & 0x08),         # Tag Owner
            'msg_tag': flags & 0x07,          # Message Tag
        }

        body = self.data[4:]
        if not body:
            return result

        if not result['som']:
            # Middle/last packet of a multi-packet message: there is no
            # IC/message-type byte, so every body byte is message data.
            result['payload'] = body
            result['payload_len'] = len(body)
            return result

        # First packet of a message: byte 4 is IC [7] + Message Type [6:0].
        type_byte = body[0]
        msg_type = type_byte & 0x7F
        result['ic'] = bool(type_byte & 0x80)
        result['msg_type'] = msg_type
        result['msg_type_name'] = self.MESSAGE_TYPES.get(
            msg_type,
            f"Vendor Defined (0x{msg_type:02X})" if msg_type >= 0x7E
            else f"Reserved (0x{msg_type:02X})"
        )

        payload = body[1:]
        result['payload'] = payload
        result['payload_len'] = len(payload)

        # Decode the SPDM header when the message type says SPDM.
        if msg_type == 0x05 and payload:
            result['spdm'] = SPDMParser.parse(payload)

        return result

    @staticmethod
    def _hex_dump(data: List[int], indent: str) -> List[str]:
        """Return hex-dump rows (16 bytes per row, offset + hex + ASCII)."""
        rows = []
        for start in range(0, len(data), 16):
            chunk = data[start:start + 16]
            hex_str = ' '.join(f'{b:02X}' for b in chunk)
            ascii_str = ''.join(chr(b) if 32 <= b < 127 else '.' for b in chunk)
            rows.append(f"{indent}{start:04X}:  {hex_str:<48}  {ascii_str}")
        return rows

    def format_output(self, parsed: dict) -> str:
        """Render a dict produced by parse() as a human-readable report.

        Args:
            parsed: Result of parse().

        Returns:
            The multi-line report as a single string.
        """
        lines = []
        lines.append("=" * 70)
        lines.append("MCTP Packet Analysis")
        lines.append("=" * 70)

        lines.append("\nTransport Header:")
        lines.append(f"  Destination EID:    0x{parsed['dest_eid']:02X} ({parsed['dest_eid']})")
        lines.append(f"  Source EID:         0x{parsed['src_eid']:02X} ({parsed['src_eid']})")
        lines.append(f"  Header Version:     {parsed['header_version']}")

        lines.append("\nMessage Framing:")
        lines.append(f"  SOM (Start):        {'Yes' if parsed['som'] else 'No'}")
        lines.append(f"  EOM (End):          {'Yes' if parsed['eom'] else 'No'}")
        lines.append(f"  Packet Sequence:    {parsed['pkt_seq']}")
        lines.append(f"  Tag Owner (TO):     {parsed['to']}")
        lines.append(f"  Message Tag:        {parsed['msg_tag']}")

        payload_len = parsed.get('payload_len', 0)

        if 'msg_type' in parsed:
            lines.append("\nMessage Body:")
            lines.append(f"  Integrity Check:    {'Yes' if parsed['ic'] else 'No'}")
            lines.append(f"  Message Type:       0x{parsed['msg_type']:02X} - {parsed['msg_type_name']}")
        elif payload_len > 0:
            lines.append("\nMessage Body:")
            lines.append("  (continuation packet - message type is only present when SOM is set)")

        spdm = parsed.get('spdm')
        if spdm:
            lines.append("\n" + "=" * 70)
            lines.append("SPDM Message (DSP0274)")
            lines.append("=" * 70)
            lines.append(f"  SPDM Version:       {spdm['version']}")
            lines.append(f"  Message Direction:  {spdm['msg_direction']}")
            lines.append(f"  Message Code:       0x{spdm['code']:02X}")
            lines.append(f"  Message Name:       {spdm['msg_name']}")

            if 'param1' in spdm:
                lines.append(f"  Param1:             0x{spdm['param1']:02X} ({spdm['param1']})")
            if 'param2' in spdm:
                lines.append(f"  Param2:             0x{spdm['param2']:02X} ({spdm['param2']})")

            if spdm['data_len'] > 0:
                lines.append(f"\n  SPDM Data ({spdm['data_len']} bytes):")
                lines.extend(self._hex_dump(spdm['data'], "    "))
        elif payload_len > 0:
            # Non-SPDM (or undecodable SPDM) payload: plain hex dump.
            lines.append(f"\nPayload ({payload_len} bytes):")
            lines.extend(self._hex_dump(parsed['payload'], "  "))

        lines.append("=" * 70)
        return '\n'.join(lines)


def calculate_pec(data: List[int]) -> int:
    """
    Compute the SMBus Packet Error Code (CRC-8) over a byte sequence.

    The CRC uses polynomial 0x07 (x^8 + x^2 + x + 1), starts from 0x00 and
    processes each byte most-significant bit first.

    Args:
        data: Bytes to run the CRC over

    Returns:
        The resulting PEC (0-255); 0 for empty input
    """
    poly = 0x07
    remainder = 0x00

    for value in data:
        remainder ^= value
        bit = 0
        while bit < 8:
            shifted = remainder << 1
            # Fold the polynomial in whenever the bit shifted out was set.
            remainder = ((shifted ^ poly) if remainder & 0x80 else shifted) & 0xFF
            bit += 1

    return remainder


def parse_hex_bytes(hex_string: str) -> List[int]:
    """Parse whitespace-separated hex bytes into a list of integers.

    Each token may optionally carry a ``0x``/``0X`` prefix, which is dropped
    before conversion.

    Args:
        hex_string: Text such as ``"0F 10 a5"``; any whitespace separates
            tokens.

    Returns:
        One integer in the range 0-255 per token; an empty list for blank
        input.

    Raises:
        ValueError: If a token is not valid hex, or its value does not fit
            in one byte.
    """
    bytes_list = []
    # str.split() with no argument already strips and drops empty tokens.
    for token in hex_string.split():
        digits = token[2:] if token.lower().startswith('0x') else token

        try:
            value = int(digits, 16)
        except ValueError:
            raise ValueError(f"Invalid hex byte: {digits}") from None

        # int() happily accepts "1FF" or "-1"; neither is a byte, and letting
        # them through would corrupt the PEC and the hex formatting later on.
        if not 0 <= value <= 0xFF:
            raise ValueError(f"Hex byte out of range (00-FF): {token}")

        bytes_list.append(value)

    return bytes_list


def parse_smbus_transport(data: List[int]) -> Optional[Dict]:
    """
    Parse the SMBus/I2C transport binding header (DSP0237).

    The input is expected to start at the command code byte, laid out as:
    command code (0x0F), byte count, source slave address, MCTP header byte
    (reserved [7:4], header version [3:0]), destination EID, source EID, then
    the MCTP flags byte and message body, optionally followed by a PEC byte.

    A trailing PEC is assumed only when the total length equals
    2 + byte_count + 1. The PEC computed here covers the supplied bytes only
    (command code through the last data byte); no I2C address byte is
    included.

    Args:
        data: Captured bytes, starting with the command code.

    Returns:
        None if the data does not look like MCTP over SMBus (wrong command
        code, byte count below 5, or too few bytes to hold the header).
        Otherwise a dict with 'cmd_code', 'byte_count', 'pec_received',
        'pec_calculated', 'pec_valid' (all three None when no PEC is
        present), 'source_slave_addr', 'mctp_reserved', 'mctp_hdr_version',
        'dest_eid', 'src_eid' and 'mctp_offset' (index of the MCTP flags
        byte within *data*).
    """
    # Command code 0x0F identifies MCTP over SMBus.
    if len(data) < 3 or data[0] != 0x0F:
        return None

    byte_count = data[1]
    # The byte count must at least cover source address, header byte,
    # both EIDs and the flags byte.
    if byte_count < 5:
        return None

    # Bytes 2..5 are indexed below; a truncated capture must not raise.
    if len(data) < 6:
        return None

    result = {}
    result['cmd_code'] = data[0]
    result['byte_count'] = byte_count

    expected_total = 2 + byte_count + 1  # cmd + count + data + PEC
    if len(data) == expected_total:
        result['pec_received'] = data[-1]
        # PEC covers every supplied byte except the PEC itself.
        result['pec_calculated'] = calculate_pec(data[:-1])
        result['pec_valid'] = (result['pec_received'] == result['pec_calculated'])
    else:
        result['pec_received'] = None
        result['pec_calculated'] = None
        result['pec_valid'] = None

    result['source_slave_addr'] = data[2]

    # MCTP header byte: reserved in the high nibble, version in the low one.
    byte_hdr = data[3]
    result['mctp_reserved'] = (byte_hdr >> 4) & 0x0F
    result['mctp_hdr_version'] = byte_hdr & 0x0F

    # Destination EID comes first, then source EID.
    result['dest_eid'] = data[4]
    result['src_eid'] = data[5]

    # The MCTP flags byte (SOM/EOM/...) is the next byte.
    result['mctp_offset'] = 6

    return result


def _print_hex_rows(values: List[int]) -> None:
    """Print *values* as two-space-indented rows of up to 16 hex bytes."""
    for start in range(0, len(values), 16):
        row = ' '.join(f'{b:02X}' for b in values[start:start + 16])
        print(f"  {row}")


def _parse_dest_addr(text: str) -> int:
    """Convert a 7-bit I2C address given as hex text to its 8-bit write address.

    Raises:
        ValueError: With a user-facing message when the text is not hex or
            the value lies outside 0x00-0x7F.
    """
    addr_str = text.strip()
    if addr_str.lower().startswith('0x'):
        addr_str = addr_str[2:]
    try:
        addr_7bit = int(addr_str, 16)
    except ValueError:
        raise ValueError(f"Invalid destination address: {text}") from None

    # int() accepts a leading '-', so the lower bound must be checked too.
    if not 0 <= addr_7bit <= 0x7F:
        shown = f"-0x{-addr_7bit:02X}" if addr_7bit < 0 else f"0x{addr_7bit:02X}"
        raise ValueError(f"Destination address must be 7-bit (0x00-0x7F), got {shown}")

    # 7-bit address shifted left by one; the write bit (bit 0) is 0.
    return addr_7bit << 1


def _announce_pec_style(dest_addr_byte: Optional[int]) -> List[int]:
    """Print which PEC convention is in use; return the bytes to prepend."""
    if dest_addr_byte is None:
        # mctp-estack style: no destination address
        print("Using mctp-estack-style PEC (no destination address)")
        return []

    # libmctp style: include destination address in PEC
    print("Using libmctp-style PEC (includes I2C destination address)")
    print(f"  Destination address (7-bit): 0x{dest_addr_byte >> 1:02X}")
    print(f"  Destination address (8-bit with W bit): 0x{dest_addr_byte:02X}")
    return [dest_addr_byte]


def main():
    """Command-line entry point.

    Reads hex bytes from the command line or a file, prints the SMBus
    transport header (when detected and not disabled), the MCTP packet
    analysis, and finally a PEC section that either verifies the last input
    byte as PEC or, with --append-pec, computes a PEC to append.

    Returns:
        0 on success, 1 on any input or parsing error.
    """
    parser = argparse.ArgumentParser(
        description='Parse MCTP packets (DSP0236 v1.3.1)',
        epilog='Data format: Space-separated hex bytes without 0x prefix (e.g., "01 02 03 04")'
    )

    parser.add_argument(
        '-i', '--input',
        type=str,
        metavar='FILE',
        help='Read packet data from file instead of command line'
    )

    parser.add_argument(
        '--no-transport',
        action='store_true',
        help='Skip I2C transport binding header detection, treat as raw MCTP'
    )

    parser.add_argument(
        '--append-pec',
        action='store_true',
        help='Calculate PEC for all input bytes (default: assumes last byte is PEC)'
    )

    parser.add_argument(
        '--dest-addr',
        type=str,
        metavar='ADDR',
        help='I2C destination address (7-bit) to prepend in PEC calculation, like libmctp does (e.g., "0x10" or "10")'
    )

    parser.add_argument(
        'data',
        nargs='*',
        help='Hex bytes to parse (space-separated, no 0x prefix)'
    )

    args = parser.parse_args()

    # Get data from file or command line
    if args.input:
        try:
            with open(args.input, 'r') as f:
                data_str = f.read()
        except FileNotFoundError:
            print(f"Error: File not found: {args.input}", file=sys.stderr)
            return 1
        except (OSError, UnicodeDecodeError) as e:
            # UnicodeDecodeError: the file is not text in the default encoding.
            print(f"Error reading file: {e}", file=sys.stderr)
            return 1
    elif args.data:
        data_str = ' '.join(args.data)
    else:
        parser.print_help()
        return 1

    # Parse hex bytes
    try:
        packet_bytes = parse_hex_bytes(data_str)
    except ValueError as e:
        print(f"Error parsing hex data: {e}", file=sys.stderr)
        return 1

    if not packet_bytes:
        print("Error: No data provided", file=sys.stderr)
        return 1

    # Parse destination address if provided
    dest_addr_byte = None
    if args.dest_addr:
        try:
            dest_addr_byte = _parse_dest_addr(args.dest_addr)
        except ValueError as e:
            print(f"Error: {e}", file=sys.stderr)
            return 1

    # Detect and handle SMBus/I2C transport binding
    smbus_info = None
    if not args.no_transport:
        smbus_info = parse_smbus_transport(packet_bytes)

    if smbus_info:
        has_pec = smbus_info['pec_received'] is not None

        # Print SMBus transport header
        print("=" * 70)
        print("SMBus/I2C Transport Binding (DSP0237)")
        print("=" * 70)
        print(f"  Command Code:       0x{smbus_info['cmd_code']:02X}")
        print(f"  Byte Count:         {smbus_info['byte_count']}")
        print(f"  Source Slave Addr:  0x{smbus_info['source_slave_addr']:02X}")
        print(f"  MCTP Reserved:      0x{smbus_info['mctp_reserved']:X}")
        print(f"  MCTP Header Ver:    {smbus_info['mctp_hdr_version']}")
        print(f"  Destination EID:    0x{smbus_info['dest_eid']:02X} ({smbus_info['dest_eid']})")
        print(f"  Source EID:         0x{smbus_info['src_eid']:02X} ({smbus_info['src_eid']})")
        if has_pec:
            print(f"  PEC (Received):     0x{smbus_info['pec_received']:02X}")
            print(f"  PEC (Calculated):   0x{smbus_info['pec_calculated']:02X}")
            if smbus_info['pec_valid']:
                print("  PEC Status:         ✓ VALID")
            else:
                print("  PEC Status:         ✗ INVALID (mismatch)")
        print()

        # Extract MCTP packet portion (starting from SOM/EOM byte)
        mctp_start = smbus_info['mctp_offset']
        if has_pec:
            mctp_end = 2 + smbus_info['byte_count']  # Exclude PEC
        else:
            mctp_end = len(packet_bytes)

        mctp_bytes = packet_bytes[mctp_start:mctp_end]

        # Parse MCTP framing and message
        try:
            # MCTPParser expects: dest_eid, hdr_ver, src_eid, som/eom,
            # ic/msg_type, payload - with the version in the HIGH nibble of
            # the header byte, hence the nibble swap below.
            synthetic_packet = [
                smbus_info['dest_eid'],
                (smbus_info['mctp_hdr_version'] << 4) | smbus_info['mctp_reserved'],
                smbus_info['src_eid']
            ] + mctp_bytes

            mctp = MCTPParser(synthetic_packet)
            parsed = mctp.parse()
            print(mctp.format_output(parsed))

            # Verify all bytes accounted for
            total_bytes = len(packet_bytes)
            accounted = 2 + smbus_info['byte_count']  # cmd + count + data
            if has_pec:
                accounted += 1

            if accounted != total_bytes:
                print(f"\nWarning: {total_bytes - accounted} byte(s) unaccounted for")

        except Exception as e:
            # Top-level boundary: report and exit instead of a traceback.
            print(f"Error parsing MCTP packet: {e}", file=sys.stderr)
            return 1
    else:
        # Parse as raw MCTP packet
        try:
            mctp = MCTPParser(packet_bytes)
            parsed = mctp.parse()
            print(mctp.format_output(parsed))
        except Exception as e:
            # Top-level boundary: report and exit instead of a traceback.
            print(f"Error parsing packet: {e}", file=sys.stderr)
            return 1

    # Always calculate and display PEC for the input data
    print("\n" + "=" * 70)
    print("PEC Calculation")
    print("=" * 70)

    if args.append_pec:
        # Calculate PEC for all input bytes (user wants to append PEC)
        pec = calculate_pec(_announce_pec_style(dest_addr_byte) + packet_bytes)

        print(f"\nInput data ({len(packet_bytes)} bytes):")
        _print_hex_rows(packet_bytes)
        print(f"\nCalculated PEC:     0x{pec:02X} ({pec})")
        print(f"\nComplete message with PEC ({len(packet_bytes) + 1} bytes):")
        _print_hex_rows(packet_bytes + [pec])
    elif len(packet_bytes) < 2:
        print("Warning: Message too short to have PEC (need at least 2 bytes)")
    else:
        # Default: assume last byte is PEC, verify it
        received_pec = packet_bytes[-1]
        data_bytes = packet_bytes[:-1]
        calculated_pec = calculate_pec(_announce_pec_style(dest_addr_byte) + data_bytes)

        print(f"\nMessage data ({len(data_bytes)} bytes):")
        _print_hex_rows(data_bytes)

        print(f"\nPEC (Received):     0x{received_pec:02X} ({received_pec})")
        print(f"PEC (Calculated):   0x{calculated_pec:02X} ({calculated_pec})")

        if received_pec == calculated_pec:
            print("PEC Status:         ✓ VALID")
        else:
            print("PEC Status:         ✗ INVALID (mismatch)")
            print(f"\nCorrect message should be ({len(packet_bytes)} bytes):")
            _print_hex_rows(data_bytes + [calculated_pec])

    print("=" * 70)

    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == '__main__':
    sys.exit(main())
