#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include "tinystdio.h"
#include <string.h>
#include "stm32f4xx_usart.h"
#include "stm32f4xx_crc.h"
#include "stm32sprog.h"
#include "serial.h"
#include "systick.h"
#include "common_config.h"


/* Debug output switch: `DBG printf(...)` expands to `if(0) printf(...)`, so
 * debug statements are still compiled (and type-checked) but never run. */
#undef DBG
#define DBG if(0)

/* Forward declarations of file-local helpers. */
static int cmdIndex(uint8_t cmd);
static bool cmdSupported(Command cmd);

static bool stmRecvAck(void);
static bool stmSendByte(uint8_t byte);
static bool stmSendAddr(uint32_t addr);
static bool stmSendBlock(const uint8_t *buffer, size_t size);
static void printProgressBar(int percent);
static bool stmReadBlock(uint32_t addr, uint8_t *buff, size_t size);
static bool stmRun();

/* Serial device handle passed to every serialRead()/serialWrite() call in
 * this file. NOTE(review): initialised to NULL and never assigned in the code
 * visible here - confirm the serial layer ignores it or that it is set
 * elsewhere. */
static SerialDev *dev = NULL;
/* Parameters of the connected target; filled in by stmGetDevParams(). */
static DeviceParameters devParams;

/* Optional flash start override used by stmGetDevParams(); 0 selects the
 * default 0x08000000. NOTE(review): not assigned in the code visible here. */
static uint32_t flash_start = 0;

/* Sleep for at least `us` microseconds using the millisecond Delay_ms().
 * The request is rounded UP to whole milliseconds (truncating could sleep
 * less than asked), with a minimum of 1 ms, so usleep(0) still waits 1 ms. */
void usleep(uint32_t us) {
    uint32_t ms = us / 1000u;
    if (us % 1000u != 0u) {
        ms++;            /* partial millisecond: round up */
    }
    if (ms == 0u) {
        ms = 1u;         /* us == 0: keep the historical 1 ms minimum */
    }
    Delay_ms(ms);
}


/* End-to-end self-test of the bootloader link.
 *
 * Sequence:
 *   1. Reboot the target via stmRebootForFlash() and connect.
 *   2. Read the device parameters.
 *   3. Erase the flash.
 *   4. Write one MAX_BLOCK_SIZE block of 0xAA at devParams.flashBeginAddr,
 *      read it back and compare.
 *   5. Erase again, whether or not the check passed.
 *
 * Returns:
 *   0 - success
 *   1 - target not detected
 *   2 - device not supported
 *   3 - first erase failed
 *   4 - write/read-back check failed
 */
int stm32sprog_test(void) {
    bool ok = false;

    stmRebootForFlash();

    /* Connect */
    if(!stmConnect()) {
        printf("STM32 not detected.\n");
        return 1;
    }
    printf("STM32 connected\n");

    /* Get params */
    if(!stmGetDevParams()) {
        printf("Device not supported.\n");
        return 2;
    }

    int major = devParams.bootloaderVer >> 4;
    int minor = devParams.bootloaderVer & 0x0F;
    printf("Bootloader version %d.%d detected.\n", major, minor);

    /* Erase */
    printf("Erasing.. ");
    if(!stmEraseFlash()) {
        printf("Erase error!\n");
        return 3;
    }
    printf("OK\n");

    /* Write a known pattern */
    uint8_t buff[MAX_BLOCK_SIZE];
    memset(buff, 0xAA, sizeof(buff));

    uint32_t addr = devParams.flashBeginAddr;
    printf("Start addr: 0x%X\n", (unsigned int)devParams.flashBeginAddr);
    printf("End addr: 0x%X\n", (unsigned int)devParams.flashEndAddr);

    printf("Writing.. ");
    ok = stmWriteBlock(addr, buff, sizeof(buff));
    usleep(devParams.writeDelay);
    if(ok) printf("Written %d bytes\r\n", (int)sizeof(buff));
    else printf("Write error!\r\n");

    /* Read back into a buffer pre-filled with a different pattern so a
     * silent short read cannot compare equal by accident. */
    printf("Comparing.. ");
    uint8_t refbuff[MAX_BLOCK_SIZE];
    memset(refbuff, 0xBB, sizeof(refbuff));

    ok = stmReadBlock(addr, refbuff, sizeof(refbuff));
    if(ok) ok = (memcmp(buff, refbuff, sizeof(buff)) == 0);
    else printf("Read error!\r\n");

    if(ok) printf("OK\r\n");
    else printf("Compare error!\r\n");

    /* Leave the target erased regardless of the test outcome. */
    printf("Erasing.. ");
    if(stmEraseFlash()) printf("OK\r\n");
    else printf("Erase error!\r\n");

    if(ok) {
        printf("Test OK\r\n");
        return 0;
    }
    printf("Test failed!\n");
    return 4;
}

/* Synchronise with the target bootloader.
 *
 * Each attempt waits ~10 ms, sends the 0x7F sync byte, then waits for an ACK.
 * Returns true on the first ACK, or false once MAX_CONNECT_RETRIES attempts
 * have gone unanswered. */
bool stmConnect(void) {
    uint8_t syncByte = 0x7F;

    for(int attempt = 1; ; ++attempt) {
        usleep(10000);
        if(attempt > MAX_CONNECT_RETRIES) {
            return false;
        }
        (void)serialWrite(dev, &syncByte, 1);
        if(stmRecvAck()) {
            return true;
        }
    }
}

/* Map a bootloader command code to its slot in devParams.commands[].
 *
 * Slots follow the order of the table below:
 *   CMD_READ_UNPROTECT -> 0, CMD_READ_PROTECT -> 1, ..., CMD_GET_VERSION -> 11.
 * Any code not in the table yields -1. */
static int cmdIndex(uint8_t cmd) {
    static const int knownCommands[] = {
        CMD_READ_UNPROTECT,
        CMD_READ_PROTECT,
        CMD_WRITE_UNPROTECT,
        CMD_WRITE_PROTECT,
        CMD_EXTENDED_ERASE,
        CMD_ERASE,
        CMD_WRITE_MEM,
        CMD_GO,
        CMD_READ_MEM,
        CMD_GET_ID,
        CMD_GET_READ_STATUS,
        CMD_GET_VERSION,
    };
    const size_t count = sizeof knownCommands / sizeof knownCommands[0];

    for(size_t slot = 0; slot < count; ++slot) {
        if(knownCommands[slot] == cmd) {
            return (int)slot;
        }
    }
    return -1;
}

/* Return true if the connected bootloader advertised `cmd` in its GET reply
 * (as recorded by stmGetDevParams()).
 *
 * A code that cmdIndex() does not know (-1), or an index past
 * NUM_COMMANDS_KNOWN, is reported as unsupported instead of indexing
 * devParams.commands[] out of bounds. */
static bool cmdSupported(Command cmd) {
    int idx = cmdIndex(cmd);
    if(idx < 0 || idx >= NUM_COMMANDS_KNOWN) {
        return false;
    }
    return devParams.commands[idx];
}

/* Read one reply byte from the target.
 * Returns true only if the read succeeded and the byte equals ACK. */
static bool stmRecvAck(void) {
    uint8_t reply = 0;
    return serialRead(dev, &reply, 1) && (reply == ACK);
}

/* Send `byte` followed by its bitwise complement, then wait for the ACK.
 * Returns false if the write fails or the reply is not ACK. */
static bool stmSendByte(uint8_t byte) {
    uint8_t frame[2];

    frame[0] = byte;
    frame[1] = (uint8_t)~byte;
    return serialWrite(dev, frame, sizeof frame) && stmRecvAck();
}

/* Send a 32-bit address most-significant byte first, followed by the XOR of
 * the four address bytes, then wait for the ACK.
 * Returns false if the write fails or the reply is not ACK. */
static bool stmSendAddr(uint32_t addr) {
    uint8_t frame[5];
    uint8_t parity = 0;

    for(int pos = 0; pos < 4; ++pos) {
        const int shift = (3 - pos) * CHAR_BIT;
        frame[pos] = (uint8_t)(addr >> shift);
        parity ^= frame[pos];
    }
    frame[4] = parity;

    if(!serialWrite(dev, frame, sizeof frame)) {
        return false;
    }
    return stmRecvAck();
}

static bool stmSendBlock(const uint8_t *buffer, size_t size) {
    //assert(size > 0 && size <= MAX_BLOCK_SIZE);
    size_t padding = (4 - (size % 4)) % 4;
    uint8_t n = size + padding - 1;
    uint8_t checksum = n;
    size_t i;
    for(i = 0; i < size; ++i) checksum ^= buffer[i];
    if(!serialWrite(dev, &n, 1)) return false;
    if(!serialWrite(dev, buffer, size)) return false;
    for(i = 0; i < padding; ++i) {
        uint8_t data = 0xFF;
        checksum ^= data;
        if(!serialWrite(dev, &data, 1)) return false;
    }
    if(!serialWrite(dev, &checksum, 1)) return false;
    return stmRecvAck();
}

/* Query the connected bootloader and fill in devParams.
 *
 * Steps:
 *   1. Load default geometry and delays.
 *   2. Send CMD_GET_VERSION. Reply: N, the bootloader version byte,
 *      N command codes, then ACK. The codes are recorded in
 *      devParams.commands[].
 *   3. Send CMD_GET_ID, expecting a 2-byte product ID (length byte N == 1).
 *      The ID selects the flash end address and page geometry.
 *
 * Returns false on any communication error, a missing GET_ID command, an
 * unexpected ID length, or an unknown product ID. */
bool stmGetDevParams(void) {
    int i;

    devParams.flashBeginAddr = flash_start ? flash_start : 0x08000000;
    devParams.flashEndAddr = 0x08008000;
    devParams.flashPagesPerSector = 4;
    devParams.flashPageSize = 1024;
    devParams.eraseDelay = 40000;
    devParams.writeDelay = 80000;

    /* Bootloader version and supported command list. */
    if(!stmSendByte(CMD_GET_VERSION)) return false;
    uint8_t numCommands = 0;
    if(!serialRead(dev, &numCommands, 1)) return false;
    if(!serialRead(dev, &devParams.bootloaderVer, 1)) return false;
    for(i = 0; i < NUM_COMMANDS_KNOWN; ++i) devParams.commands[i] = false;
    for(i = 0; i < numCommands; ++i) {
        uint8_t code = 0;
        if(!serialRead(dev, &code, 1)) return false;
        int idx = cmdIndex(code);
        if(idx >= 0) devParams.commands[idx] = true;
    }
    if(!stmRecvAck()) return false;

    /* Product ID. */
    if(!cmdSupported(CMD_GET_ID)) {
        printf("Target device does not support GET_ID command.\n");
        return false;
    }
    if(!stmSendByte(CMD_GET_ID)) return false;
    uint8_t idLen = 0;
    if(!serialRead(dev, &idLen, 1)) return false;
    if(idLen != 1) return false;          /* only 2-byte IDs are understood */
    uint8_t pid[2];
    if(!serialRead(dev, pid, sizeof(pid))) return false;
    uint16_t id = (uint16_t)((pid[0] << CHAR_BIT) | pid[1]);   /* MSB first */
    if(!stmRecvAck()) return false;

    switch(id) {
    case ID_LOW_DENSITY:
        devParams.flashEndAddr = 0x08008000;
        break;
    case ID_MED_DENSITY:
        devParams.flashEndAddr = 0x08020000;
        break;
    case ID_HI_DENSITY:
        devParams.flashEndAddr = 0x08080000;
        devParams.flashPagesPerSector = 2;
        devParams.flashPageSize = 2048;
        break;
    case ID_CONNECTIVITY:
        devParams.flashEndAddr = 0x08040000;
        devParams.flashPagesPerSector = 2;
        devParams.flashPageSize = 2048;
        break;
    case ID_VALUE:
        devParams.flashEndAddr = 0x08020000;
        break;
    case ID_HI_DENSITY_VALUE:
        devParams.flashEndAddr = 0x08080000;
        devParams.flashPagesPerSector = 2;
        devParams.flashPageSize = 2048;
        break;
    case ID_XL_DENSITY:
        devParams.flashEndAddr = 0x08100000;
        devParams.flashPagesPerSector = 2;
        devParams.flashPageSize = 2048;
        break;
    case ID_MED_DENSITY_ULTRA_LOW_POWER:
        /* STM32L1 medium density: up to 128 KB of flash (RM0038). */
        devParams.flashEndAddr = 0x08020000;
        devParams.flashPagesPerSector = 16;
        devParams.flashPageSize = 256;
        break;
    case ID_HI_DENSITY_ULTRA_LOW_POWER:
        /* STM32L1 high density: up to 384 KB of flash (RM0038). */
        devParams.flashEndAddr = 0x08060000;
        devParams.flashPagesPerSector = 16;
        devParams.flashPageSize = 256;
        break;
    default:
        printf("Device id 0x%X not supported\n", (unsigned int)id);
        return false;
    }

    return true;
}

/* Erase the whole target flash.
 *
 * - Standard Erase is preferred: the page-count byte 0xFF (sent with its
 *   complement) requests a global erase.
 * - Otherwise Extended Erase is used with the 0xFFFF mass-erase code and
 *   checksum 0x00.
 *
 * Returns true when the bootloader ACKs the erase, or false on a
 * communication error or if neither erase command is supported. */
bool stmEraseFlash(void) {
    if(cmdSupported(CMD_ERASE)) {
        if(!stmSendByte(CMD_ERASE)) return false;
        return stmSendByte(0xFF);
    }

    if(cmdSupported(CMD_EXTENDED_ERASE)) {
        /* CMD_EXTENDED_ERASE not tested */
        printf("CMD_EXTENDED_ERASE supported\r\n");

        if(!stmSendByte(CMD_EXTENDED_ERASE)) return false;
        /* 0xFFFF = mass erase; 0x00 = XOR checksum of 0xFF ^ 0xFF. */
        uint8_t data[] = { 0xFF, 0xFF, 0x00 };
        if(!serialWrite(dev, data, sizeof(data))) return false;
        return stmRecvAck();
    }

    printf("Target device does not support known erase commands.\n");
    return false;
}

/* Erase only FW flash pages, keep settings page.
 *
 * Uses the standard Erase command to erase pages 0 .. DB_CPU_SETTINGS_PAGE-1.
 * After the command, the frame is:
 *   - N = page count - 1;
 *   - the N + 1 page numbers;
 *   - an XOR checksum over N and the page numbers.
 *
 * N must stay within 0..254, because 0xFF is the global-erase code. A
 * settings page of 0 (N would wrap to 0xFF) or above 255 would wipe the
 * settings page too, so such values are rejected before the command is sent.
 *
 * Returns true when the bootloader ACKs the erase. */
bool stmEraseFW(void) {
    if(!cmdSupported(CMD_ERASE)) {
        printf("Target device does not support known erase commands.\n");
        return false;
    }

    const uint32_t pages = (uint32_t)DB_CPU_SETTINGS_PAGE;
    if(pages == 0 || pages > 255) {
        printf("Invalid FW page count: %u\n", (unsigned int)pages);
        return false;
    }

    if(!stmSendByte(CMD_ERASE)) return false;

    uint8_t n = (uint8_t)(pages - 1);   /* n + 1 pages will be erased */
    if(!serialWrite(dev, &n, 1)) return false;
    uint8_t checksum = n;

    for(uint32_t page = 0; page < pages; ++page) {
        uint8_t code = (uint8_t)page;
        if(!serialWrite(dev, &code, 1)) return false;
        checksum ^= code;
    }

    if(!serialWrite(dev, &checksum, 1)) return false;
    return stmRecvAck();
}

bool stmWriteBlock(uint32_t addr, const uint8_t *buff, size_t size) {
    if(!stmSendByte(CMD_WRITE_MEM)) return false;
    if(!stmSendAddr(addr)) return false;
    if(!stmSendBlock(buff, size)) return false;
    return true;
}

static bool stmReadBlock(uint32_t addr, uint8_t *buff, size_t size) {
    if(!stmSendByte(CMD_READ_MEM)) return false;
    if(!stmSendAddr(addr)) return false;
    if(!stmSendByte(size - 1)) return false;
    return serialRead(dev, buff, size);
}

/* Restart the target with BOOT0 held high across a reset pulse, then drop
 * BOOT0 again. Each phase lasts 60 ms.
 *
 * NOTE(review): presumably this selects the ROM bootloader at reset. It
 * assumes IO_ClearDbReset() asserts the target reset and IO_SetDbReset()
 * releases it - confirm against the IO layer. */
void stmRebootForFlash(void) {
    const uint32_t phaseUs = 60000;

    IO_SetDbBoot0();        /* BOOT0 high while coming out of reset */
    IO_ClearDbReset();
    usleep(phaseUs);
    IO_SetDbReset();
    usleep(phaseUs);        /* keep BOOT0 stable until the target has started */
    IO_ClearDbBoot0();
}

/* Restart the target with BOOT0 low: drive the reset line via
 * IO_ClearDbReset() for 60 ms, then release it with IO_SetDbReset().
 * Returns immediately after release, without waiting for the target to boot.
 *
 * NOTE(review): assumes BOOT0 low means a normal boot from flash - confirm. */
void stmReboot(void) {
    const uint32_t holdUs = 60000;

    IO_ClearDbBoot0();
    IO_ClearDbReset();
    usleep(holdUs);
    IO_SetDbReset();
}

uint32_t stmCalcFlashCrc(void (* periodic_handler)(uint8_t)) {
    bool ok = false;
    uint8_t buf[4];
    uint32_t crc = 0;
    uint8_t block[MAX_BLOCK_SIZE];
    uint32_t last_block_size, n = 0;
    volatile uint32_t* ptr;
    static uint32_t last_progress = 0;

    CRC_ResetDR();

#if 0
    /* Read by 4 bytes blocks. Slow method */
    for(uint32_t* ptr = (uint32_t*)DB_CPU_FLASH_FIRST_PAGE_ADDRESS; ptr != (uint32_t*)DB_CPU_FLASH_CRC_ADDRESS; ptr++) {
        ok = stmReadBlock((uint32_t)ptr, buf, 4);
        if(ok) {
            crc = CRC_CalcCRC(*(uint32_t *)buf);
            DBG printf("Verify block %u\r\n", n++);
        }
        else printf("Device read error!\r\n");
    }
#endif

    /* Read by MAX_BLOCK_SIZE bytes blocks. Fast method */
    for (ptr = (uint32_t*)DB_CPU_FLASH_FIRST_PAGE_ADDRESS; ptr < (uint32_t*)DB_CPU_FLASH_CRC_ADDRESS - MAX_BLOCK_SIZE / 4; ptr += MAX_BLOCK_SIZE/4) {

        ok = stmReadBlock((uint32_t)ptr, block, MAX_BLOCK_SIZE);
        if (ok) {
            crc = CRC_CalcBlockCRC((uint32_t *)block, MAX_BLOCK_SIZE / 4);
            DBG printf("Verify block %u\r\n", n++);
            if (periodic_handler != NULL) {
                uint32_t progress = ((uint32_t)ptr - DB_CPU_FLASH_FIRST_PAGE_ADDRESS) * 100 /
                    (DB_CPU_FLASH_CRC_ADDRESS - MAX_BLOCK_SIZE / 4 - DB_CPU_FLASH_FIRST_PAGE_ADDRESS);
                if (progress > last_progress) {
                    last_progress = progress;
                    periodic_handler(progress);
                }
            }
        }
        else printf("Device read error!\r\n");
    }

    last_block_size = ((uint32_t*)DB_CPU_FLASH_CRC_ADDRESS - ptr) * 4;
    ok = stmReadBlock((uint32_t)ptr, block, last_block_size);
    if(ok) {
        crc = CRC_CalcBlockCRC((uint32_t *)block, last_block_size / 4);
        DBG printf("Verify block %u\r\n", n++);
        if (periodic_handler != NULL) {
            periodic_handler(100);
        }
    }
    else printf("Device read error!\r\n");

    return crc;
}

/* Read the 32-bit CRC word stored on the target at DB_CPU_FLASH_CRC_ADDRESS
 * and print it.
 *
 * Returns the stored word, or 0 if the read fails (a read error is also
 * printed). The bytes are copied with memcpy in host byte order, which avoids
 * the unaligned, strict-aliasing-violating `*(uint32_t *)buf` access. */
uint32_t stmReadFlashCrc(void) {
    uint8_t buf[4];
    uint32_t crc = 0x0;

    if(stmReadBlock(DB_CPU_FLASH_CRC_ADDRESS, buf, sizeof(buf))) {
        memcpy(&crc, buf, sizeof(crc));
    } else {
        printf("Read error!\r\n");
    }

    printf("Written CRC: 0x%X\r\n", (unsigned int)crc);

    return crc;
}

/* Program `len` bytes from `ptr` into target flash at *addr.
 *
 * - Data is written in blocks of at most MAX_BLOCK_SIZE bytes.
 * - *addr advances by the number of bytes written, so consecutive calls
 *   continue where the previous one stopped.
 * - On the first write error the error is printed and the function returns;
 *   *addr then points at the failed block.
 *
 * NOTE(review): stmSendBlock() pads each block to a multiple of 4 with 0xFF,
 * but *addr only advances by the unpadded length. Callers presumably pass
 * lengths that are multiples of 4 - confirm. */
void stmProg( uint32_t* addr, uint8_t * ptr, uint32_t len)
{
    static uint32_t n = 0;      /* block counter, only used by debug output */
    const uint32_t maxChunk = (uint32_t)MAX_BLOCK_SIZE;
    uint32_t offset = 0;

    if (addr == NULL || (ptr == NULL && len > 0)) return;

    /* Reset blocks count when a new image starts at the first FW page. */
    if (*addr == DB_CPU_FLASH_FIRST_PAGE_ADDRESS) n = 0;

    while (offset < len) {
        uint32_t remaining = len - offset;
        uint32_t chunk = (remaining < maxChunk) ? remaining : maxChunk;

        bool ok = stmWriteBlock(*addr, ptr + offset, chunk);
        /* An 8 ms pause follows every full-size block (not the shorter tail). */
        if (chunk == maxChunk) usleep(8000);

        if (!ok) {
            printf("Device write error!\r\n");
            return;
        }
        DBG printf("Block %u (0x%X) %ub\r\n", (unsigned int)n++,
                   (unsigned int)*addr, (unsigned int)chunk);

        *addr += chunk;
        offset += chunk;
    }
}

/* Send CMD_GO followed by devParams.flashBeginAddr.
 * Returns true only if both the command and the address were ACKed.
 * NOTE(review): not called from the code visible here. */
static bool stmRun() {
    return stmSendByte(CMD_GO) && stmSendAddr(devParams.flashBeginAddr);
}

/* Redraw a 70-column text progress bar on the current line.
 * Output has the form "\r NN%[=====     ]". The filled part is
 * percent * 70 / 100 columns, clamped to 0..70.
 * NOTE(review): not called from the code visible here. */
static void printProgressBar(int percent) {
    const int width = 70;
    int filled = percent * width / 100;

    if (filled < 0) filled = 0;
    if (filled > width) filled = width;

    printf("\r%3d%%[", percent);
    for (int col = 0; col < filled; ++col) {
        printf("=");
    }
    for (int col = filled; col < width; ++col) {
        printf(" ");
    }
    printf("]");
}