import os
import time
from typing import Iterable, Union, Optional, Tuple
from serial import Serial


class ErrorAT32(OSError):
	"""Base error raised by the FlashAT32 utility.

	Raised in this module for read/ACK/connection timeouts, a NACK reply,
	an unexpected response length and an unsupported device.
	``IOError`` is an alias of ``OSError`` in Python 3, so existing
	``except IOError`` handlers keep catching it.
	"""


# Anything FlashAT32.to_bytes() can turn into raw bytes:
# str is encoded, a single int becomes one byte, everything else goes through bytes().
SendData = Union[str, bytes, bytearray, Iterable[int], int]


class FlashAT32:
	"""Artery AT32x flashing utility.

	Talks to the device bootloader over a serial port (even parity). Every
	command is a byte followed by its bitwise complement and is answered by
	CMD_ACK or CMD_NACK. Payloads are followed by an XOR checksum byte.

	The object can be used as a context manager so the serial port is always
	closed::

		with FlashAT32('/dev/ttyUSB0', 115200) as dev:
			dev.connect()
	"""
	CMD_SYNC = 0x7F
	CMD_ACK = 0x79
	CMD_NACK = 0x1F
	ADDR_FLASH = 0x08000000
	CHUNK_SIZE = 256
	DEBUG_PRINT = False
	# The length of a read/write is transmitted as a single byte holding
	# (length - 1), so one transfer carries between 1 and 256 bytes.
	_MAX_TRANSFER = 256

	def __init__(self, tty: str, baudrate: int):
		"""Open serial port `tty` at `baudrate` with even parity and a 1 s read timeout"""
		self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)

	def close(self):
		"""Close the underlying serial port"""
		self.serial.close()

	def __enter__(self) -> 'FlashAT32':
		return self

	def __exit__(self, exc_type, exc_value, traceback):
		self.close()

	@staticmethod
	def to_bytes(data: 'SendData') -> bytes:
		"""Convert various types of data to bytes.

		str is encoded with the default codec, a single int becomes one byte
		(ValueError if it is outside 0..255), anything else is passed to bytes().
		"""
		if isinstance(data, str):
			return data.encode()
		if isinstance(data, int):
			return bytes((data, ))
		return bytes(data)

	@staticmethod
	def checksum(buf: bytes) -> int:
		"""Calculate the XOR checksum of `buf` (this is not a CRC).

		For two or more bytes the result is the XOR of all bytes. For a single
		byte the accumulator starts at 0xFF, so the result is the bitwise
		complement of that byte. An empty buffer yields 0xFF.
		"""
		crc = 0 if len(buf) > 1 else 0xFF
		for b in buf:
			crc ^= b
		return crc

	def send(self, data: 'SendData', add_checksum: bool = False):
		"""Send data to device, optionally followed by its checksum byte"""
		buf = self.to_bytes(data)
		if add_checksum:
			buf += bytes((self.checksum(buf), ))
		if self.DEBUG_PRINT:
			print(f'> {buf}')
		self.serial.write(buf)

	def recv(self, count: int, timeout: float = 1) -> bytes:
		"""Receive up to `count` bytes from device; may return fewer (or none) on timeout"""
		self.serial.timeout = timeout
		buf = self.serial.read(count)
		if self.DEBUG_PRINT:
			print(f'< {buf}')
		return buf

	def read(self, count: int, timeout: float = 3) -> bytes:
		"""Read exactly `count` bytes from device.

		Raises ErrorAT32 if the bytes do not arrive within `timeout` seconds.
		"""
		buf = bytearray()
		deadline = time.monotonic() + timeout
		while True:
			remaining = deadline - time.monotonic()
			if remaining <= 0:
				raise ErrorAT32('Read timeout')
			# Never block on the port longer than the time that is left.
			buf.extend(self.recv(count - len(buf), min(1, remaining)))
			if len(buf) == count:
				return bytes(buf)

	def receive_ack(self, timeout: float = 10) -> Optional[bool]:
		"""Wait ACK byte from device, return True on ACK, False on NACK, None on timeout.

		Bytes that are neither ACK nor NACK are skipped.
		"""
		ack = bytes((self.CMD_ACK, ))
		nack = bytes((self.CMD_NACK, ))
		deadline = time.monotonic() + timeout
		while True:
			remaining = deadline - time.monotonic()
			if remaining <= 0:
				return None
			# Never block on the port longer than the time that is left.
			buf = self.recv(1, min(1, remaining))
			if buf == ack:
				return True
			if buf == nack:
				return False

	def wait_ack(self, timeout: float = 10):
		"""Wait ACK byte from device, raise ErrorAT32 on NACK or timeout"""
		res = self.receive_ack(timeout)
		if res is None:
			raise ErrorAT32('ACK timeout')
		if not res:
			raise ErrorAT32('NACK received')

	def send_cmd(self, code: int):
		"""Send command byte with its complement and wait ACK"""
		self.send((code, ~code & 0xFF))
		self.wait_ack()

	def set_isp(self):
		"""Send host data to device.

		Issues command 0xFA, then a fixed 4-byte payload with checksum, and
		waits for an ACK after each step. Raises ErrorAT32 on NACK or timeout.
		"""
		self.send_cmd(0xFA)
		self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
		self.wait_ack()

	def connect(self, timeout: float = 10):
		"""Init connection with AT32 bootloader.

		Keeps sending CMD_SYNC, waiting 10 ms for an ACK after each one, and
		runs set_isp() once the ACK arrives. Raises ErrorAT32 if no ACK is
		seen within `timeout` seconds.
		"""
		ack = bytes((self.CMD_ACK, ))
		start_time = time.monotonic()
		while time.monotonic() - start_time < timeout:
			self.send(self.CMD_SYNC)
			if self.recv(1, 0.01) == ack:
				return self.set_isp()
		raise ErrorAT32('Connection timeout')

	def access_unprotect(self):
		"""Disable access protection (command 0x92, acknowledged twice)"""
		self.send_cmd(0x92)
		self.wait_ack()

	def get_device_id(self) -> Tuple[int, int]:
		"""Read some device info.

		Returns (project_code, device_code). Raises ErrorAT32 if the response
		length byte is not 4 or the project code is not a supported one.
		"""
		self.send_cmd(0x02)
		length = self.read(1)[0]
		if length != 4:
			raise ErrorAT32('Incorrect response length')
		buf = self.read(length + 1)
		project_code = buf[4]
		# The two 16-bit halves each arrive high byte first, low half first.
		device_code = buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
		if project_code not in {8, }:
			raise ErrorAT32('Device not supported')
		return project_code, device_code

	def reset(self):
		"""SW reboot for device (command 0xD4, acknowledged twice)"""
		self.send_cmd(0xD4)
		self.wait_ack()

	def read_mem(self, address: int, count: int) -> bytes:
		"""Read `count` bytes (1..256) of device memory starting at `address`.

		Raises ValueError for an out-of-range count, OverflowError for an
		address that does not fit into 32 bits, ErrorAT32 on protocol errors.
		"""
		# Validate everything before the first byte is sent, so bad arguments
		# cannot leave the bootloader in the middle of a transaction.
		if not 1 <= count <= self._MAX_TRANSFER:
			raise ValueError(f'count must be in range 1..{self._MAX_TRANSFER}, got {count}')
		addr_bytes = address.to_bytes(4, 'big')
		self.send_cmd(0x11)
		self.send(addr_bytes, add_checksum=True)
		self.wait_ack()
		self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
		self.wait_ack()
		return self.read(count)

	def write_mem(self, address: int, buf: bytes):
		"""Write `buf` (1..256 bytes) to device memory starting at `address`.

		Raises ValueError for an out-of-range length, OverflowError for an
		address that does not fit into 32 bits, ErrorAT32 on protocol errors.
		"""
		# Validate everything before the first byte is sent, so bad arguments
		# cannot leave the bootloader in the middle of a transaction.
		if not 1 <= len(buf) <= self._MAX_TRANSFER:
			raise ValueError(f'buffer length must be in range 1..{self._MAX_TRANSFER}, got {len(buf)}')
		addr_bytes = address.to_bytes(4, 'big')
		self.send_cmd(0x31)
		self.send(addr_bytes, add_checksum=True)
		self.wait_ack()
		self.send(bytes((len(buf) - 1, )) + bytes(buf), add_checksum=True)
		self.wait_ack()

	def get_flash_size(self) -> int:
		"""Read flash size: the little-endian 16-bit value stored at 0x1FFFF7E0"""
		return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little')

	def get_uid(self) -> bytes:
		"""Read unique chip ID (12 bytes at 0x1FFFF7E8)"""
		# NOTE: an earlier revision read the ID in 3-byte pieces; a single
		# 12-byte read is used here.
		return self.read_mem(0x1FFFF7E8, 12)

	def get_uid_str(self) -> str:
		"""Get unique chip ID as upper-case hex string"""
		return ''.join(f'{b:02X}' for b in self.get_uid())

	def read_flash(self, address: int, size: int) -> bytes:
		"""Read `size` bytes starting at `address`, CHUNK_SIZE bytes per request.

		Raises ValueError if `size` is negative.
		"""
		if size < 0:
			raise ValueError(f'size must not be negative, got {size}')
		res = bytearray()
		for offset in range(0, size, self.CHUNK_SIZE):
			res.extend(self.read_mem(address + offset, min(self.CHUNK_SIZE, size - offset)))
		return bytes(res)

	def read_flash_to_file(self, address: int, size: int, path: str):
		"""Read `size` bytes starting at `address` and store them in file `path`"""
		with open(path, 'wb') as f:
			f.write(self.read_flash(address, size))

	def erase_flash(self):
		"""Do mass erase"""
		self.send_cmd(0x44)
		self.send((0xFF, 0xFF), add_checksum=True)
		self.wait_ack()

	def write_flash(self, address: int, buf: bytes):
		"""Write binary data to flash, CHUNK_SIZE bytes per request"""
		for offset in range(0, len(buf), self.CHUNK_SIZE):
			self.write_mem(address + offset, buf[offset:offset + self.CHUNK_SIZE])

	def write_file_to_flash(self, address: int, path: str):
		"""Write binary file `path` to flash starting at `address`"""
		with open(path, 'rb') as f:
			self.write_flash(address, f.read())


def main():
	"""Hardware smoke test: erase, flash two images, read them back and compare.

	Expects a device on /dev/ttyUSB0 and the images under ../output/.
	Prints the flash size, the chip UID, a line per image whose readback
	differs from the source file, and the elapsed time in seconds.
	"""
	import filecmp

	started = time.monotonic()
	# (flash address, image to write, temporary readback file)
	images = (
		(0x08000000, '../output/iap.bin', '../output/m3_artery_iap.bin'),
		(0x08021000, '../output/fw.bin', '../output/m3_artery_fw.bin'),
	)
	d = FlashAT32('/dev/ttyUSB0', 115200 * 8)
	d.DEBUG_PRINT = True
	try:
		try:
			d.connect()
		except ErrorAT32 as e:
			# Best effort: report the failed handshake and still try the
			# commands below. NOTE(review): presumably this covers a
			# bootloader that is already synchronised - confirm.
			print(e)
		print(d.get_flash_size())
		print(d.get_uid_str())
		d.erase_flash()
		for address, path, _ in images:
			d.write_file_to_flash(address, path)
		try:
			for address, path, readback in images:
				d.read_flash_to_file(address, os.path.getsize(path), readback)
			for _, path, readback in images:
				# Byte-wise comparison; no external `diff` and no shell involved.
				if not filecmp.cmp(path, readback, shallow=False):
					print(f'Files {path} and {readback} differ')
		finally:
			# Remove the temporary readback files even if a step above failed.
			for _, _, readback in images:
				if os.path.exists(readback):
					os.remove(readback)
	finally:
		d.serial.close()

	print(time.monotonic() - started)


if __name__ == '__main__':
	main()