#include <Arduino.h>
#include <SPI.h>
#include <bitset>
#include <cstddef>
#include <stdint.h>

#include "gonk/adc.h"
#include "gonk_adc/adc_measurement.pwpb.h"

#define PW_LOG_LEVEL PW_LOG_LEVEL_INFO
#define PW_LOG_MODULE_NAME "Adc"

#include "pw_bytes/bit.h"
#include "pw_bytes/endian.h"
#include "pw_bytes/span.h"
#include "pw_log/log.h"
#include "pw_result/result.h"
#include "pw_span/span.h"
#include "pw_status/status.h"

namespace gonk::adc {

namespace {

// Latched FPGA "valid" pulse flag: set to 1 by io_valid_rising_isr() and
// reset to 0 by ClearValidPulse(). It is volatile so that the polling loop in
// WaitForFpgaIOValidPulse() re-reads it on every iteration instead of caching
// a value the interrupt handler may change underneath it.
volatile uint8_t fpga_valid_pulse_ = 0;

// Interrupt handler for the FPGA valid pin: records that a pulse was seen.
// Deliberately does nothing else. Logging from interrupt context can block on
// the log backend's output or re-enter a non-reentrant logger, so no log call
// is made here even at debug level.
void io_valid_rising_isr() { fpga_valid_pulse_ = 1; }

// Re-arms pulse detection; call before starting an operation whose
// completion pulse will be waited for.
void ClearValidPulse() { fpga_valid_pulse_ = 0; }

} // namespace

// INA229 register offsets. Only the low 6 bits are sent on the wire
// (ADCAddress() masks the offset with 0x3f).
constexpr uint8_t INA229_CONFIG = 0x0;
constexpr uint8_t INA229_ADC_CONFIG = 0x1;
constexpr uint8_t INA229_SHUNT_CALIBRATION = 0x2;
constexpr uint8_t INA229_SHUNT_TEMP_COEFFICIENT = 0x3;
constexpr uint8_t INA229_VSHUNT = 0x4;
constexpr uint8_t INA229_VBUS = 0x5;
constexpr uint8_t INA229_DIETEMP = 0x6;
constexpr uint8_t INA229_CURRENT = 0x7;
constexpr uint8_t INA229_POWER = 0x8;
constexpr uint8_t INA229_ENERGY = 0x9;
constexpr uint8_t INA229_CHARGE = 0xA;
constexpr uint8_t INA229_DIAG_ALERT = 0xB;
constexpr uint8_t INA229_SHUNT_OVERVOLT_THRESHOLD = 0xC;
constexpr uint8_t INA229_SHUNT_UNDERVOLT_THRESHOLD = 0xD;
constexpr uint8_t INA229_BUS_OVERVOLT_THRESHOLD = 0xE;
constexpr uint8_t INA229_BUS_UNDERVOLT_THRESHOLD = 0xF;
constexpr uint8_t INA229_TEMP_LIMIT = 0x10;
constexpr uint8_t INA229_POWER_LIMIT = 0x11;

// Identification registers; these lie outside the contiguous range covered
// by INA229RegisterByteSize below.
constexpr uint8_t INA229_MANUFACTURER_ID = 0x3e;
constexpr uint8_t INA229_DEVICE_ID = 0x3f;

// Width in bytes of every register from INA229_CONFIG (0x00) through
// INA229_POWER_LIMIT (0x11), indexed by register offset.
//
// Entries are listed in register order with plain aggregate initialization.
// C99-style array designators ("[INDEX] = value") are not ISO C++ (only a
// GNU extension), so the index mapping is instead pinned by the
// static_asserts that follow.
constexpr uint8_t INA229RegisterByteSize[]{
    2, // 0x00 INA229_CONFIG
    2, // 0x01 INA229_ADC_CONFIG
    2, // 0x02 INA229_SHUNT_CALIBRATION
    2, // 0x03 INA229_SHUNT_TEMP_COEFFICIENT
    3, // 0x04 INA229_VSHUNT
    3, // 0x05 INA229_VBUS
    2, // 0x06 INA229_DIETEMP
    3, // 0x07 INA229_CURRENT
    3, // 0x08 INA229_POWER
    5, // 0x09 INA229_ENERGY
    5, // 0x0A INA229_CHARGE
    2, // 0x0B INA229_DIAG_ALERT
    2, // 0x0C INA229_SHUNT_OVERVOLT_THRESHOLD
    2, // 0x0D INA229_SHUNT_UNDERVOLT_THRESHOLD
    2, // 0x0E INA229_BUS_OVERVOLT_THRESHOLD
    2, // 0x0F INA229_BUS_UNDERVOLT_THRESHOLD
    2, // 0x10 INA229_TEMP_LIMIT
    2, // 0x11 INA229_POWER_LIMIT
};
static_assert(sizeof(INA229RegisterByteSize) == INA229_POWER_LIMIT + 1,
              "INA229RegisterByteSize must have one entry per register "
              "from INA229_CONFIG to INA229_POWER_LIMIT");
static_assert(INA229RegisterByteSize[INA229_VSHUNT] == 3 &&
                  INA229RegisterByteSize[INA229_VBUS] == 3 &&
                  INA229RegisterByteSize[INA229_CURRENT] == 3 &&
                  INA229RegisterByteSize[INA229_POWER] == 3,
              "24-bit registers out of place in INA229RegisterByteSize");
static_assert(INA229RegisterByteSize[INA229_ENERGY] == 5 &&
                  INA229RegisterByteSize[INA229_CHARGE] == 5,
              "40-bit registers out of place in INA229RegisterByteSize");
constexpr pw::span<const uint8_t>
    kINA229RegisterByteSize(INA229RegisterByteSize);

// Binds the driver to its serial output stream and FPGA SPI bus and records
// the FPGA control pins. The SPI settings are fixed to MSB-first, SPI mode 1
// at the caller-supplied baud rate. Pins are stored only; no pin or bus
// configuration happens here.
Adc::Adc(Stream &serial_stream, SPIClass &fpga_spi, uint32_t fpga_spi_baudrate,
         uint16_t fpga_cs_pin, uint16_t fpga_mode_pin, uint16_t fpga_reset_pin,
         uint16_t fpga_valid_pin)
    : serial_(serial_stream), fpga_spi_(fpga_spi), fpga_reset_(fpga_reset_pin),
      fpga_mode_(fpga_mode_pin), fpga_valid_(fpga_valid_pin),
      fpga_cs_(fpga_cs_pin),
      spi_settings_(fpga_spi_baudrate, MSBFIRST, SPI_MODE1) {
  // Start with an empty read-time history and no prior measurement time.
  sample_read_time_.fill(0);
  previous_timestamp_ = 0;
  measurement_timestamp_ = 0;
  sample_read_index_ = 0;
}

void Adc::SetReadWriteMode() {
  // Reset FPGA logic
  digitalWrite(fpga_reset_, HIGH);
  delay(10);
  // Toggle mode to 0
  digitalWrite(fpga_mode_, LOW);
  delay(10);
  // Release reset
  digitalWrite(fpga_reset_, LOW);
  delay(10);
}

void Adc::SetContinuousReadMode() {
  // Reset FPGA logic
  digitalWrite(fpga_reset_, HIGH);
  delay(10);

  // Valid signal can go high from the FPGA side as soon as mode is set to 1.
  ClearValidPulse();

  // Toggle mode to 1
  digitalWrite(fpga_mode_, HIGH);
  delay(10);

  // Release reset
  digitalWrite(fpga_reset_, LOW);
  delay(10);
}

// Busy-waits until the FPGA valid pin reads HIGH.
//
// @param timeout_ms  Maximum time to wait, in milliseconds.
// @return OkStatus() once the pin is high, or DeadlineExceeded() if more than
//         timeout_ms milliseconds elapse first.
Status Adc::WaitForFpgaIOValid(uint32_t timeout_ms = 2000) {
  PW_LOG_DEBUG("FPGA Valid Signal: Waiting");

  const uint32_t start_ms = millis();
  while (digitalRead(fpga_valid_) != HIGH) {
    // Measure elapsed time with unsigned subtraction: it stays correct when
    // millis() wraps around, whereas comparing against start_ms + timeout_ms
    // can overflow and report a timeout immediately.
    const uint32_t elapsed_ms = static_cast<uint32_t>(millis()) - start_ms;
    if (elapsed_ms > timeout_ms) {
      PW_LOG_ERROR("FPGA Valid Signal: Timeout");
      return pw::Status::DeadlineExceeded();
    }
  }

  PW_LOG_DEBUG("FPGA Valid Signal: Result Ready");
  return pw::OkStatus();
}

// Busy-waits until the valid-pin interrupt handler has latched a pulse
// (fpga_valid_pulse_ == 1). The flag is not cleared here; callers clear it
// with ClearValidPulse() before starting the operation being waited on.
//
// @param timeout_ms  Maximum time to wait, in milliseconds.
// @return OkStatus() once a pulse is seen, or DeadlineExceeded() if more than
//         timeout_ms milliseconds elapse first.
Status Adc::WaitForFpgaIOValidPulse(uint32_t timeout_ms = 2000) {
  PW_LOG_DEBUG("FPGA Valid Pulse: Waiting");

  const uint32_t start_ms = millis();
  while (fpga_valid_pulse_ != 1) {
    // Unsigned subtraction keeps the timeout correct across millis() wrap.
    const uint32_t elapsed_ms = static_cast<uint32_t>(millis()) - start_ms;
    if (elapsed_ms > timeout_ms) {
      PW_LOG_ERROR("FPGA Valid Pulse: Timeout");
      return pw::Status::DeadlineExceeded();
    }
  }

  PW_LOG_DEBUG("FPGA Valid Pulse: Detected");
  return pw::OkStatus();
}

uint32_t Adc::ADCAddress(uint8_t adc_number, uint8_t adc_command,
                         uint8_t mode = 1) {
  // Address + Read/Write bit (24 bits)
  //
  // [23:18] - Dont Care bits
  // [17:7] - ADC Select. These 11 bits are used to select one or multiple ADCs
  //          for a write operation.
  //   [17] - ADC11 Select (ADC # from schematics)
  //   [16] - ADC10 Select
  //   ...
  //   [8] - ADC2 Select
  //   [7] - ADC1 Select
  //   NoTE: Select only one ADC for a read operation.
  // [6:1] - Register offset address for INA229.
  // [0] -  R/W bit. 1: READ, 0: WRITE

  const uint8_t adc_index = adc_number - 1;

  return (
      // ADC selection
      1 << (adc_index + 7)
      // ADC Register
      | ((adc_command & 0x3f) << 1)
      // 1=read, 0=write
      | mode);
}

// Address word for writing register adc_command of a single ADC: the same
// layout as ADCAddress(), with the R/W bit (bit 0) cleared.
uint32_t Adc::ADCAddressWrite(uint8_t adc_number, uint8_t adc_command) {
  constexpr uint8_t kWriteMode = 0;
  return ADCAddress(adc_number, adc_command, kWriteMode);
}

// Address word that writes register adc_command on all 11 ADCs at once: every
// ADC select bit [17:7] set, register offset in [6:1], R/W bit [0] = 0 (write).
uint32_t Adc::ADCAddressWriteAll(uint8_t adc_command) {
  const uint32_t all_adcs_selected = 0x7FF << 7;
  const uint32_t register_bits = (adc_command & 0x3f) << 1;
  const uint32_t write_bit = 0;
  return all_adcs_selected | register_bits | write_bit;
}

// Returns the width in bytes of INA229 register adc_register, or OutOfRange()
// (after logging) for an offset that is neither an ID register nor in the
// size table. The two ID registers (0x3e, 0x3f) sit past the table's
// 0x00..0x11 range, so they are answered before the bounds check.
pw::Result<uint8_t> Adc::RegisterSize(uint8_t adc_register) {
  switch (adc_register) {
  case INA229_MANUFACTURER_ID:
  case INA229_DEVICE_ID:
    return 2;
  default:
    break;
  }

  const bool in_table = adc_register < kINA229RegisterByteSize.size();
  if (!in_table) {
    PW_LOG_ERROR("Invalid register: %x", adc_register);
    return pw::Status::OutOfRange();
  }
  return INA229RegisterByteSize[adc_register];
}

void Adc::StartSpiTransaction() { fpga_spi_.beginTransaction(spi_settings_); }

void Adc::EndSpiTransaction() { fpga_spi_.endTransaction(); }

// Sends the low 24 bits of adc_address over SPI, most-significant byte first.
void Adc::WriteAddress(uint32_t adc_address) {
  uint8_t address_bytes[3];
  for (size_t n = 0; n < 3; ++n) {
    const size_t shift = 8 * (2 - n);
    address_bytes[n] = (adc_address >> shift) & 0xff;
  }

  // NOTE: logged least-significant byte first (reverse of send order).
  PW_LOG_DEBUG("WriteAddress: %x %x %x", address_bytes[2], address_bytes[1],
               address_bytes[0]);

  fpga_spi_.transfer(address_bytes, 3);
}

// Clocks read_buffer.size() bytes in from the FPGA into read_buffer.
//
// @return A span aliasing read_buffer; it is only valid for as long as the
//         caller's buffer is.
pw::Result<pw::ConstByteSpan> Adc::ReadData(pw::ByteSpan read_buffer) {
  // The buffer overload of SPIClass::transfer() is full duplex and works in
  // place: the buffer's current contents are shifted out while the reply is
  // shifted in. Zero it first so the FPGA sees deterministic 0x00 filler
  // rather than stale or uninitialized memory.
  for (std::byte &b : read_buffer) {
    b = std::byte{0};
  }
  fpga_spi_.transfer(read_buffer.data(), read_buffer.size());

  return pw::ConstByteSpan(read_buffer);
}

// Shifts the bytes of write_buffer out to the FPGA over SPI. Always returns
// OkStatus().
//
// NOTE: SPIClass::transfer(buf, n) is full duplex and in place, so on return
// write_buffer holds the bytes received from the FPGA, not the bytes sent.
Status Adc::WriteData(pw::ByteSpan write_buffer) {
  std::byte *const bytes = write_buffer.data();
  const size_t count = write_buffer.size();
  fpga_spi_.transfer(bytes, count);
  return pw::OkStatus();
}

// Reads the 2-byte MANUFACTURER_ID register (0x3e) of one ADC. The returned
// span has the same lifetime as any GetRegister() result.
pw::Result<pw::ConstByteSpan> Adc::GetManufacturerID(uint8_t adc_number) {
  constexpr uint8_t kRegister = INA229_MANUFACTURER_ID;
  return GetRegister(adc_number, kRegister);
}

// Reads INA229 register adc_register from ADC adc_number through the FPGA.
//
// Sequence: validate the register, send the read address, wait for the FPGA
// valid pin, then clock in RegisterSize(adc_register) bytes.
//
// @return On success, a span over the register bytes as received (most
//         significant first). The span points into a function-local static
//         buffer shared by all Adc instances: it remains valid after return
//         but is overwritten by the next GetRegister() call, so consume it
//         before reading another register. Not reentrant.
//         On failure: OutOfRange() for an unknown register, or the status
//         of the failed wait/read.
pw::Result<pw::ConstByteSpan> Adc::GetRegister(uint8_t adc_number,
                                               uint8_t adc_register) {
  // Validate before touching the bus so an unknown register produces no SPI
  // traffic and cannot leave a transaction open.
  const pw::Result<uint8_t> register_size = RegisterSize(adc_register);
  if (!register_size.ok()) {
    return register_size.status();
  }

  // Static storage: a span over a stack buffer would dangle once this
  // function returns. Sized for the widest INA229 register (5 bytes), which
  // bounds every value RegisterSize() can return.
  static std::array<std::byte, 5> read_buffer;
  read_buffer.fill(std::byte{0});

  StartSpiTransaction();
  WriteAddress(ADCAddress(adc_number, adc_register));

  const Status wait_result = WaitForFpgaIOValid();
  if (!wait_result.ok()) {
    EndSpiTransaction();
    return wait_result;
  }

  const pw::Result<pw::ConstByteSpan> read_result =
      ReadData(pw::ByteSpan(read_buffer.data(), register_size.value()));
  // End the transaction on both success and failure.
  EndSpiTransaction();

  return read_result;
}

// Clocks three bytes in from the FPGA. Does not begin or end an SPI
// transaction itself — presumably the caller already holds one (confirm at
// call sites).
//
// @return A span over the three bytes. It points into a function-local
//         static buffer, so it stays valid after return but is overwritten by
//         the next GetThreeBytes() call. Not reentrant.
pw::Result<pw::ConstByteSpan> Adc::GetThreeBytes() {
  // Static so the returned span does not dangle once this function returns.
  static std::array<std::byte, 3> read_buffer;
  read_buffer.fill(std::byte{0});

  return ReadData(pw::ByteSpan(read_buffer));
}

// Reads one full set of continuous-mode samples from the FPGA into
// measurements_.
//
// Waits for the FPGA valid pin, timestamps the set, then for each of the
// kMaxStreamingAdcCount ADCs reads 3 VBUS bytes followed by 3 VSHUNT bytes,
// decodes them with VoltageMeasurement(), and stores both raw bytes and
// decoded values in measurements_[i]. The time spent in this call is
// recorded in the sample_read_time_ ring buffer.
//
// @return OkStatus(), or the status of the failed wait/read. The SPI
//         transaction is always ended before returning.
Status Adc::UpdateContinuousMeasurements() {
  const uint32_t start_time = micros();

  StartSpiTransaction();

  const Status wait_result = WaitForFpgaIOValid();
  if (!wait_result.ok()) {
    EndSpiTransaction();
    return wait_result;
  }

  previous_timestamp_ = measurement_timestamp_;
  measurement_timestamp_ = micros();
  measurement_delta_micros_ = measurement_timestamp_ - previous_timestamp_;

  for (int i = 0; i < kMaxStreamingAdcCount; i++) {
    // Read VBUS. Value-initialized so the bytes shifted out during the
    // full-duplex read are zeros rather than uninitialized stack data.
    std::array<std::byte, 3> vbus_read_buffer{};
    const pw::Result<pw::ConstByteSpan> vbus_read_result =
        ReadData(pw::ByteSpan(vbus_read_buffer));
    if (!vbus_read_result.ok()) {
      PW_LOG_ERROR("vbus_read failed i=%d", i);
      EndSpiTransaction();
      return vbus_read_result.status();
    }
    const pw::ConstByteSpan vbus_bytes = vbus_read_result.value();
    const int32_t vbus_value = VoltageMeasurement(vbus_bytes);
    // std::byte is a scoped enum and is not promoted through C varargs, so it
    // is converted explicitly for %02x; int32_t may be `long`, hence the cast.
    PW_LOG_DEBUG("Continuous Read ADC #%02d: VBUS   = %02x %02x %02x = %d", i,
                 std::to_integer<unsigned>(vbus_bytes[0]),
                 std::to_integer<unsigned>(vbus_bytes[1]),
                 std::to_integer<unsigned>(vbus_bytes[2]),
                 static_cast<int>(vbus_value));

    // Read VSHUNT.
    std::array<std::byte, 3> vshunt_read_buffer{};
    const pw::Result<pw::ConstByteSpan> vshunt_read_result =
        ReadData(pw::ByteSpan(vshunt_read_buffer));
    if (!vshunt_read_result.ok()) {
      PW_LOG_ERROR("vshunt_read failed i=%d", i);
      EndSpiTransaction();
      return vshunt_read_result.status();
    }
    const pw::ConstByteSpan vshunt_bytes = vshunt_read_result.value();
    const int32_t vshunt_value = VoltageMeasurement(vshunt_bytes);
    PW_LOG_DEBUG("Continuous Read ADC #%02d: VSHUNT = %02x %02x %02x = %d", i,
                 std::to_integer<unsigned>(vshunt_bytes[0]),
                 std::to_integer<unsigned>(vshunt_bytes[1]),
                 std::to_integer<unsigned>(vshunt_bytes[2]),
                 static_cast<int>(vshunt_value));

    // Save this measurement. The byte loop has its own index `b`: reusing
    // `i` would shadow the ADC index and scatter the raw bytes across
    // measurements_[0..2] instead of storing them in measurements_[i].
    measurements_[i].vbus_value_ = vbus_value;
    measurements_[i].vshunt_value_ = vshunt_value;
    for (size_t b = 0; b < 3; b++) {
      measurements_[i].vbus_bytes_[b] = vbus_bytes[b];
      measurements_[i].vshunt_bytes_[b] = vshunt_bytes[b];
    }
  }

  // All data has been read: clear the pulse signal variable.
  ClearValidPulse();

  EndSpiTransaction();

  const uint32_t end_time = micros();
  sample_read_time_[sample_read_index_] = end_time - start_time;
  sample_read_index_ = (sample_read_index_ + 1) % sample_read_time_.size();

  return pw::OkStatus();
}

// WriteMeasurementPacket() encodes exactly kMaxStreamingAdcCount entries into
// the Payload's adc_measurements field and sizes its buffer from
// Payload::kAdcMeasurementsMaxSize, so the two limits must agree.
static_assert(Payload::kAdcMeasurementsMaxSize == kMaxStreamingAdcCount,
              "Payload::kAdcMeasurementsMaxSize must equal "
              "kMaxStreamingAdcCount");

// Encodes the latest measurements_ as a FramedProto packet and writes it to
// serial_.
//
// The packet holds the magic start constant, then a payload with
// measurement_delta_micros_ as the timestamp and one AdcMeasure (vbus/vshunt
// values) per streamed ADC. The packet is built on the stack, written to
// serial_, and serial_ is flushed.
//
// @return OkStatus(), or the first encoding error. Nothing is written to
//         serial_ on error.
Status Adc::WriteMeasurementPacket() {
  std::array<std::byte, FramedProto::kMaxEncodedSizeBytes +
                            (AdcMeasure::kMaxEncodedSizeBytes *
                             Payload::kAdcMeasurementsMaxSize)>
      packet_buffer;
  FramedProto::MemoryEncoder packet(packet_buffer);
  Status status = packet.WriteMagicStart(kFramedProtoMagicConstant);
  if (!status.ok()) {
    // pw::Status is a class, not an int: log it via str() rather than %d.
    PW_LOG_ERROR("WriteMagicStart %s", status.str());
    return status;
  }

  // Scope bounds the payload encoder's lifetime; it must be finished before
  // the parent packet's status is checked.
  {
    auto payload = packet.GetPayloadEncoder();
    status = payload.WriteTimestamp(measurement_delta_micros_);
    if (!status.ok()) {
      PW_LOG_ERROR("WriteTimestamp: %s", status.str());
      return status;
    }
    for (size_t i = 0; i < kMaxStreamingAdcCount; i++) {
      auto adc_encoder = payload.GetAdcMeasurementsEncoder();
      status = adc_encoder.WriteVbusValue(measurements_[i].vbus_value_);
      if (!status.ok()) {
        PW_LOG_ERROR("WriteVbusValue: #%d %s", static_cast<int>(i),
                     status.str());
        return status;
      }
      status = adc_encoder.WriteVshuntValue(measurements_[i].vshunt_value_);
      if (!status.ok()) {
        PW_LOG_ERROR("WriteVshuntValue: #%d %s", static_cast<int>(i),
                     status.str());
        return status;
      }
    }
  }

  // Log and return the packet's own status (not the last field status).
  const Status packet_status = packet.status();
  if (!packet_status.ok()) {
    PW_LOG_ERROR("packet.status %s", packet_status.str());
    return packet_status;
  }

  // Write proto packet bytes over serial.
  serial_.write(reinterpret_cast<const uint8_t *>(packet.data()),
                packet.size());
  serial_.flush();

  return pw::OkStatus();
}

// Writes write_buffer to register adc_register of ADC adc_number and waits
// for the FPGA valid pulse that follows.
//
// write_buffer is passed to WriteData(), whose in-place SPI transfer may
// overwrite its contents with the bytes received.
//
// @return OkStatus(), the WriteData() error, or DeadlineExceeded() if no
//         valid pulse arrives in time. The SPI transaction is always ended.
Status Adc::WriteRegister(uint8_t adc_number, uint8_t adc_register,
                          pw::ByteSpan write_buffer) {
  const uint32_t adc_address = ADCAddressWrite(adc_number, adc_register);

  StartSpiTransaction();
  // Clear any stale pulse before sending, so the wait below presumably sees
  // only the pulse produced by this write.
  ClearValidPulse();
  WriteAddress(adc_address);

  const Status write_status = WriteData(write_buffer);
  if (!write_status.ok()) {
    EndSpiTransaction();
    return write_status;
  }

  const Status wait_result = WaitForFpgaIOValidPulse();
  EndSpiTransaction();

  return wait_result;
}

// Writes write_buffer to register adc_register on all ADCs in one SPI write
// (see ADCAddressWriteAll()) and waits for the FPGA valid pulse.
//
// write_buffer is passed to WriteData(), whose in-place SPI transfer may
// overwrite its contents with the bytes received.
//
// @return OkStatus(), the WriteData() error, or DeadlineExceeded() if no
//         valid pulse arrives in time. The SPI transaction is always ended.
Status Adc::WriteRegisterAll(uint8_t adc_register, pw::ByteSpan write_buffer) {
  const uint32_t adc_address = ADCAddressWriteAll(adc_register);

  StartSpiTransaction();
  // Clear any stale pulse before sending, so the wait below presumably sees
  // only the pulse produced by this write.
  ClearValidPulse();
  WriteAddress(adc_address);

  const Status write_status = WriteData(write_buffer);
  if (!write_status.ok()) {
    EndSpiTransaction();
    return write_status;
  }

  const Status wait_result = WaitForFpgaIOValidPulse();
  EndSpiTransaction();

  return wait_result;
}

// Reads the 2-byte ADC_CONFIG register of one ADC (see GetRegister() for the
// returned span's lifetime).
pw::Result<pw::ConstByteSpan> Adc::GetADCConfiguration(uint8_t adc_number) {
  const uint8_t reg = INA229_ADC_CONFIG;
  return GetRegister(adc_number, reg);
}

// Writes write_buffer to the ADC_CONFIG register of one ADC.
Status Adc::SetADCConfiguration(uint8_t adc_number, pw::ByteSpan write_buffer) {
  const uint8_t reg = INA229_ADC_CONFIG;
  return WriteRegister(adc_number, reg, write_buffer);
}

// Reads the 2-byte SHUNT_CAL register of one ADC (see GetRegister() for the
// returned span's lifetime).
pw::Result<pw::ConstByteSpan> Adc::GetShuntCalibration(uint8_t adc_number) {
  const uint8_t reg = INA229_SHUNT_CALIBRATION;
  return GetRegister(adc_number, reg);
}

// Reads the 3-byte VSHUNT register of one ADC and decodes it with
// VoltageMeasurement(). The raw bytes and decoded value are logged at INFO.
//
// @return The decoded signed 20-bit reading, or the GetRegister() error.
pw::Result<int32_t> Adc::GetShuntVoltageMeasurement(uint8_t adc_number) {
  pw::Result<pw::ConstByteSpan> read_result =
      GetRegister(adc_number, INA229_VSHUNT);
  if (!read_result.ok()) {
    return read_result.status();
  }

  // Consume the span immediately; it refers to GetRegister()'s read buffer.
  const pw::ConstByteSpan bytes = read_result.value();
  const int32_t value = VoltageMeasurement(bytes);
  // std::byte is a scoped enum and is not promoted through C varargs, so it
  // is converted explicitly for %02x; int32_t may be `long`, hence the cast.
  PW_LOG_INFO("ADC #%02d: VSHUNT = %02x %02x %02x = %d", adc_number,
              std::to_integer<unsigned>(bytes[0]),
              std::to_integer<unsigned>(bytes[1]),
              std::to_integer<unsigned>(bytes[2]), static_cast<int>(value));

  return value;
}

// Reads the 3-byte VBUS register of one ADC and decodes it with
// VoltageMeasurement(). The raw bytes and decoded value are logged at INFO.
//
// @return The decoded signed 20-bit reading, or the GetRegister() error.
pw::Result<int32_t> Adc::GetBusVoltageMeasurement(uint8_t adc_number) {
  pw::Result<pw::ConstByteSpan> read_result =
      GetRegister(adc_number, INA229_VBUS);
  if (!read_result.ok()) {
    return read_result.status();
  }

  // Consume the span immediately; it refers to GetRegister()'s read buffer.
  const pw::ConstByteSpan bytes = read_result.value();
  const int32_t value = VoltageMeasurement(bytes);
  // std::byte is a scoped enum and is not promoted through C varargs, so it
  // is converted explicitly for %02x; int32_t may be `long`, hence the cast.
  PW_LOG_INFO("ADC #%02d: VBUS   = %02x %02x %02x = %d", adc_number,
              std::to_integer<unsigned>(bytes[0]),
              std::to_integer<unsigned>(bytes[1]),
              std::to_integer<unsigned>(bytes[2]), static_cast<int>(value));

  return value;
}

// Decodes a 3-byte, most-significant-first register value whose 20-bit
// two's-complement reading occupies bits 23..4; bits 3..0 are discarded.
//
//   byte 0 -> result bits 19..12
//   byte 1 -> result bits 11..4
//   byte 2 -> result bits 3..0 (its high nibble only)
//
// Expects read_buffer to hold at least 3 bytes.
int32_t Adc::VoltageMeasurement(pw::ConstByteSpan read_buffer) {
  const uint32_t high = static_cast<uint32_t>(read_buffer[0]);
  const uint32_t middle = static_cast<uint32_t>(read_buffer[1]);
  const uint32_t low = static_cast<uint32_t>(read_buffer[2]);

  const uint32_t raw_20_bits = (high << 12) | (middle << 4) | (low >> 4);
  return pw::bytes::SignExtend<20>(raw_20_bits);
}

// One-time ADC setup: installs the FPGA valid-pin interrupt, switches the
// FPGA to read/write mode, and writes ADC_CONFIG and DIAG_ALRT on all ADCs.
//
// @return OkStatus(), or the status of the first register write that failed
//         (later writes are skipped).
Status Adc::InitAdcs() {
  // Latch rising edges of the valid pin into the flag polled by
  // WaitForFpgaIOValidPulse(). RISING matches the handler's purpose; HIGH is
  // a level trigger on cores that honor it and would keep re-entering the
  // ISR for as long as the pin stays high. digitalPinToInterrupt() maps the
  // pin number to the core's interrupt number, as the Arduino API expects.
  attachInterrupt(/*interrupt=*/digitalPinToInterrupt(fpga_valid_),
                  /*callback=*/&io_valid_rising_isr,
                  /*mode=*/RISING);

  SetReadWriteMode();

  PW_LOG_INFO("Init ADCs");

  // Set all ADC_CONFIG registers (high byte first). Field layout per the
  // INA229 datasheet — confirm against the datasheet revision in use:
  //   [15-12] MODE=0xB:   Continuous shunt and bus voltage
  //   [11-9]  VBUSCT=0x0: 50us conversion time
  //   [8-6]   VSHCT=0x0:  50us conversion time
  //   [5-3]   VTCT=0x0:   50us conversion time
  //   [2-0]   AVG=0x0:    1 sample averaging count
  std::array<std::byte, 2> adc_config = {
      std::byte(0b10110000),
      std::byte(0b00000000),
  };
  Status status = WriteRegisterAll(INA229_ADC_CONFIG, adc_config);
  if (!status.ok()) {
    PW_LOG_ERROR("Init ADCs: ADC_CONFIG write failed: %s", status.str());
    return status;
  }

  // Set all DIAG_ALRT registers (high byte first):
  //   [14] CNVR=1: enable the conversion ready flag.
  //   All other bits 0. [0] MEMSTAT is, per the datasheet, a read-only
  //   status flag, so the 0 written there is presumably ignored.
  std::array<std::byte, 2> diag_alert = {
      std::byte(0b01000000),
      std::byte(0b00000000),
  };
  status = WriteRegisterAll(INA229_DIAG_ALERT, diag_alert);
  if (!status.ok()) {
    PW_LOG_ERROR("Init ADCs: DIAG_ALRT write failed: %s", status.str());
    return status;
  }

  return pw::OkStatus();
}

// Probes ADCs 1..kTotalAdcCount: logs each ADC_CONFIG register, then reads
// (and logs) VSHUNT and VBUS.
//
// A failed ADC_CONFIG read stops the check and returns that error. Failed
// VSHUNT/VBUS reads are only logged, and checking continues.
//
// @return OkStatus() if every ADC_CONFIG read succeeded.
Status Adc::CheckAllAdcs() {
  for (int adc_number = 1; adc_number <= kTotalAdcCount; adc_number++) {
    // Read and log the ADC_CONFIG register.
    const pw::Result<pw::ConstByteSpan> adc_config_result =
        GetADCConfiguration(adc_number);
    if (!adc_config_result.ok()) {
      PW_LOG_ERROR("ADC #%02d: GetADCConfiguration failed: %s", adc_number,
                   adc_config_result.status().str());
      return adc_config_result.status();
    }
    // The register is 2 bytes, most significant first.
    const uint16_t adc_value = pw::bytes::ReadInOrder<uint16_t>(
        pw::endian::big, adc_config_result.value().data());
    PW_LOG_INFO("ADC #%02d: ADC Config = %x", adc_number, adc_value);

    // Read and log the VSHUNT register. pw::Status is a class, so it is
    // logged via str() rather than passed to %d.
    const pw::Result<int32_t> vshunt_result =
        GetShuntVoltageMeasurement(adc_number);
    if (!vshunt_result.ok()) {
      PW_LOG_ERROR("ADC #%02d: GetShuntVoltageMeasurement failed: %s",
                   adc_number, vshunt_result.status().str());
    }

    // Read and log the VBUS register.
    const pw::Result<int32_t> vbus_result =
        GetBusVoltageMeasurement(adc_number);
    if (!vbus_result.ok()) {
      PW_LOG_ERROR("ADC #%02d: GetBusVoltageMeasurement failed: %s",
                   adc_number, vbus_result.status().str());
    }
  }

  return pw::OkStatus();
}

} // namespace gonk::adc
