|  | @@ -0,0 +1,233 @@
 | 
											
												
													
														|  | 
 |  | +import os
 | 
											
												
													
														|  | 
 |  | +import time
 | 
											
												
													
														|  | 
 |  | +from typing import Iterable, Union, Optional, Tuple
 | 
											
												
													
														|  | 
 |  | +from serial import Serial
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class ErrorAT32(IOError):
 | 
											
												
													
														|  | 
 |  | +	"""Common error for FlashAT32 util"""
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +SendData = Union[str, bytes, bytearray, Iterable[int], int]
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class FlashAT32:
 | 
											
												
													
														|  | 
 |  | +	"""Artery AT32x flashing utility"""
 | 
											
												
													
														|  | 
 |  | +	CMD_SYNC = 0x7F
 | 
											
												
													
														|  | 
 |  | +	CMD_ACK = 0x79
 | 
											
												
													
														|  | 
 |  | +	CMD_NACK = 0x1F
 | 
											
												
													
														|  | 
 |  | +	ADDR_FLASH = 0x08000000
 | 
											
												
													
														|  | 
 |  | +	CHUNK_SIZE = 256
 | 
											
												
													
														|  | 
 |  | +	DEBUG_PRINT = False
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	# def __init__(self, tty: str, baudrate: int = 115200):
 | 
											
												
													
														|  | 
 |  | +	def __init__(self, tty: str, baudrate: int):
 | 
											
												
													
														|  | 
 |  | +		self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	@staticmethod
 | 
											
												
													
														|  | 
 |  | +	def to_bytes(data: SendData) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Convert various types of data to bytes"""
 | 
											
												
													
														|  | 
 |  | +		if isinstance(data, str):
 | 
											
												
													
														|  | 
 |  | +			return data.encode()
 | 
											
												
													
														|  | 
 |  | +		elif isinstance(data, int):
 | 
											
												
													
														|  | 
 |  | +			return bytes((data, ))
 | 
											
												
													
														|  | 
 |  | +		else:
 | 
											
												
													
														|  | 
 |  | +			return bytes(data)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	@staticmethod
 | 
											
												
													
														|  | 
 |  | +	def checksum(buf: bytes) -> int:
 | 
											
												
													
														|  | 
 |  | +		"""Calculate standard CRC-8 checksum"""
 | 
											
												
													
														|  | 
 |  | +		crc = 0 if len(buf) > 1 else 0xFF
 | 
											
												
													
														|  | 
 |  | +		for b in buf:
 | 
											
												
													
														|  | 
 |  | +			crc = crc.__xor__(b)
 | 
											
												
													
														|  | 
 |  | +		return crc
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def send(self, data: SendData, add_checksum: bool = False):
 | 
											
												
													
														|  | 
 |  | +		"""Send data to device"""
 | 
											
												
													
														|  | 
 |  | +		buf = self.to_bytes(data)
 | 
											
												
													
														|  | 
 |  | +		if add_checksum:
 | 
											
												
													
														|  | 
 |  | +			buf += bytes((self.checksum(buf), ))
 | 
											
												
													
														|  | 
 |  | +		if self.DEBUG_PRINT:
 | 
											
												
													
														|  | 
 |  | +			print(f'> {buf}')
 | 
											
												
													
														|  | 
 |  | +		self.serial.write(buf)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def recv(self, count: int, timeout: float = 1) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Receive bytes from device"""
 | 
											
												
													
														|  | 
 |  | +		self.serial.timeout = timeout
 | 
											
												
													
														|  | 
 |  | +		buf = self.serial.read(count)
 | 
											
												
													
														|  | 
 |  | +		if self.DEBUG_PRINT:
 | 
											
												
													
														|  | 
 |  | +			print(f'< {buf}')
 | 
											
												
													
														|  | 
 |  | +		return buf
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def read(self, count: int, timeout: float = 3) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Read count bytes from device or raise exception"""
 | 
											
												
													
														|  | 
 |  | +		buf = bytearray()
 | 
											
												
													
														|  | 
 |  | +		start_time = time.monotonic()
 | 
											
												
													
														|  | 
 |  | +		while time.monotonic() - start_time < timeout:
 | 
											
												
													
														|  | 
 |  | +			buf.extend(self.recv(count - len(buf)))
 | 
											
												
													
														|  | 
 |  | +			if len(buf) == count:
 | 
											
												
													
														|  | 
 |  | +				return bytes(buf)
 | 
											
												
													
														|  | 
 |  | +		raise ErrorAT32('Read timeout')
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def receive_ack(self, timeout: float = 10) -> Optional[bool]:
 | 
											
												
													
														|  | 
 |  | +		"""Wait ACK byte from device, return True on ACN, False on NACK, None on timeout"""
 | 
											
												
													
														|  | 
 |  | +		start_time = time.monotonic()
 | 
											
												
													
														|  | 
 |  | +		while time.monotonic() - start_time < timeout:
 | 
											
												
													
														|  | 
 |  | +			buf = self.recv(1)
 | 
											
												
													
														|  | 
 |  | +			if buf == bytes((self.CMD_ACK, )):
 | 
											
												
													
														|  | 
 |  | +				return True
 | 
											
												
													
														|  | 
 |  | +			if buf == bytes((self.CMD_NACK, )):
 | 
											
												
													
														|  | 
 |  | +				return False
 | 
											
												
													
														|  | 
 |  | +		return None
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def wait_ack(self, timeout: float = 10):
 | 
											
												
													
														|  | 
 |  | +		"""Wait ACK byte from device, raise exception on NACK or timeout"""
 | 
											
												
													
														|  | 
 |  | +		res = self.receive_ack(timeout)
 | 
											
												
													
														|  | 
 |  | +		if res is None:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('ACK timeout')
 | 
											
												
													
														|  | 
 |  | +		if not res:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('NACK received')
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def send_cmd(self, code: int):
 | 
											
												
													
														|  | 
 |  | +		"""Send command and wait ACK"""
 | 
											
												
													
														|  | 
 |  | +		self.send((code, ~code & 0xFF))
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def set_isp(self):
 | 
											
												
													
														|  | 
 |  | +		"""Send host data to device"""
 | 
											
												
													
														|  | 
 |  | +		self.send((0xFA, 0x05))
 | 
											
												
													
														|  | 
 |  | +		ack = self.receive_ack()
 | 
											
												
													
														|  | 
 |  | +		if ack is None:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('ACK timeout')
 | 
											
												
													
														|  | 
 |  | +		if not ack:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('NACK received')
 | 
											
												
													
														|  | 
 |  | +			return
 | 
											
												
													
														|  | 
 |  | +		self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def connect(self, timeout: float = 10):
 | 
											
												
													
														|  | 
 |  | +		"""Init connection with AT32 bootloader"""
 | 
											
												
													
														|  | 
 |  | +		start_time = time.monotonic()
 | 
											
												
													
														|  | 
 |  | +		while time.monotonic() - start_time < timeout:
 | 
											
												
													
														|  | 
 |  | +			self.send(self.CMD_SYNC)
 | 
											
												
													
														|  | 
 |  | +			if self.recv(1, 0.01) == bytes((self.CMD_ACK, )):
 | 
											
												
													
														|  | 
 |  | +				return self.set_isp()
 | 
											
												
													
														|  | 
 |  | +		raise ErrorAT32('Connection timeout')
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def access_unprotect(self):
 | 
											
												
													
														|  | 
 |  | +		"""Disable access protection"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0x92)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def get_device_id(self) -> Tuple[int, int]:
 | 
											
												
													
														|  | 
 |  | +		"""Read some device info"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0x02)
 | 
											
												
													
														|  | 
 |  | +		length = self.read(1)[0]
 | 
											
												
													
														|  | 
 |  | +		if length != 4:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('Incorrect response length')
 | 
											
												
													
														|  | 
 |  | +		buf = self.read(length + 1)
 | 
											
												
													
														|  | 
 |  | +		project_code, device_code = buf[4], buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
 | 
											
												
													
														|  | 
 |  | +		if project_code not in {8, }:
 | 
											
												
													
														|  | 
 |  | +			raise ErrorAT32('Device not supported')
 | 
											
												
													
														|  | 
 |  | +		return project_code, device_code
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def reset(self):
 | 
											
												
													
														|  | 
 |  | +		"""SW reboot for device"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0xD4)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def read_mem(self, address: int, count: int) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Read data from device memory"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0x11)
 | 
											
												
													
														|  | 
 |  | +		self.send(address.to_bytes(4, 'big'), add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +		self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +		return self.read(count)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def write_mem(self, address: int, buf: bytes):
 | 
											
												
													
														|  | 
 |  | +		"""Write data to device memory"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0x31)
 | 
											
												
													
														|  | 
 |  | +		self.send(address.to_bytes(4, 'big'), add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +		self.send(bytes((len(buf) - 1, )) + buf, add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def get_flash_size(self) -> int:
 | 
											
												
													
														|  | 
 |  | +		"""Read flash size"""
 | 
											
												
													
														|  | 
 |  | +		return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little') & 0x0000FFFF
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def get_uid(self) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Read unique chip ID"""
 | 
											
												
													
														|  | 
 |  | +		return self.read_mem(0x1FFFF7E8, 12)
 | 
											
												
													
														|  | 
 |  | +		# buf = bytearray()
 | 
											
												
													
														|  | 
 |  | +		# x = 3
 | 
											
												
													
														|  | 
 |  | +		# for i in range(12 // x):
 | 
											
												
													
														|  | 
 |  | +		#	 buf.extend(self.read_mem(0x1FFFF7E8 + i * x, 1 * x))
 | 
											
												
													
														|  | 
 |  | +		# return bytes(buf)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def get_uid_str(self) -> str:
 | 
											
												
													
														|  | 
 |  | +		"""Get unique chip ID as hex string"""
 | 
											
												
													
														|  | 
 |  | +		return ''.join(f'{b:02X}' for b in self.get_uid())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def read_flash(self, address: int, size: int) -> bytes:
 | 
											
												
													
														|  | 
 |  | +		"""Read all content of flash memory"""
 | 
											
												
													
														|  | 
 |  | +		chunks, last_bytes_count = divmod(size, self.CHUNK_SIZE)
 | 
											
												
													
														|  | 
 |  | +		res = bytearray()
 | 
											
												
													
														|  | 
 |  | +		for i in range(chunks):
 | 
											
												
													
														|  | 
 |  | +			res.extend(self.read_mem(address + i * self.CHUNK_SIZE, self.CHUNK_SIZE))
 | 
											
												
													
														|  | 
 |  | +		if last_bytes_count != 0:
 | 
											
												
													
														|  | 
 |  | +			res.extend(self.read_mem(address + chunks * self.CHUNK_SIZE, last_bytes_count))
 | 
											
												
													
														|  | 
 |  | +		return bytes(res)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def read_flash_to_file(self, address: int, size: int, path: str):
 | 
											
												
													
														|  | 
 |  | +		"""Read all content of flash memory to file"""
 | 
											
												
													
														|  | 
 |  | +		with open(path, 'wb') as f:
 | 
											
												
													
														|  | 
 |  | +			f.write(self.read_flash(address, size))	
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def erase_flash(self):
 | 
											
												
													
														|  | 
 |  | +		"""Do mass erase"""
 | 
											
												
													
														|  | 
 |  | +		self.send_cmd(0x44)
 | 
											
												
													
														|  | 
 |  | +		self.send((0xFF, 0xFF), add_checksum=True)
 | 
											
												
													
														|  | 
 |  | +		self.wait_ack()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def write_flash(self, address: int, buf: bytes):
 | 
											
												
													
														|  | 
 |  | +		"""Write binary data to flash"""
 | 
											
												
													
														|  | 
 |  | +		flash_full_chunks, last_chunk_size = divmod(len(buf), self.CHUNK_SIZE)
 | 
											
												
													
														|  | 
 |  | +		for i in range(flash_full_chunks):
 | 
											
												
													
														|  | 
 |  | +			self.write_mem(address + i * self.CHUNK_SIZE, buf[i * self.CHUNK_SIZE:(i + 1) * self.CHUNK_SIZE])
 | 
											
												
													
														|  | 
 |  | +		if last_chunk_size > 0:
 | 
											
												
													
														|  | 
 |  | +			self.write_mem(address + flash_full_chunks * self.CHUNK_SIZE,
 | 
											
												
													
														|  | 
 |  | +						   buf[flash_full_chunks * self.CHUNK_SIZE:])
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	def write_file_to_flash(self, address: int, path: str):
 | 
											
												
													
														|  | 
 |  | +		"""write binary file to flash"""
 | 
											
												
													
														|  | 
 |  | +		with open(path, 'rb') as f:
 | 
											
												
													
														|  | 
 |  | +			self.write_flash(address, f.read())
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +if __name__ == '__main__':
 | 
											
												
													
														|  | 
 |  | +	start_time = time.time()
 | 
											
												
													
														|  | 
 |  | +	d = FlashAT32('/dev/ttyUSB0', 115200 * 8)
 | 
											
												
													
														|  | 
 |  | +	d.DEBUG_PRINT = True
 | 
											
												
													
														|  | 
 |  | +	try:
 | 
											
												
													
														|  | 
 |  | +		d.connect()
 | 
											
												
													
														|  | 
 |  | +	except Exception as e:
 | 
											
												
													
														|  | 
 |  | +		print(e)
 | 
											
												
													
														|  | 
 |  | +	print(d.get_flash_size())
 | 
											
												
													
														|  | 
 |  | +	print(d.get_uid_str())
 | 
											
												
													
														|  | 
 |  | +	d.erase_flash()
 | 
											
												
													
														|  | 
 |  | +	iap_path = '../output/iap.bin'
 | 
											
												
													
														|  | 
 |  | +	fw_path = '../output/fw.bin'
 | 
											
												
													
														|  | 
 |  | +	iap_path_r = '../output/m3_artery_iap.bin'
 | 
											
												
													
														|  | 
 |  | +	fw_path_r = '../output/m3_artery_fw.bin'
 | 
											
												
													
														|  | 
 |  | +	d.write_file_to_flash(0x08000000, iap_path)
 | 
											
												
													
														|  | 
 |  | +	d.write_file_to_flash(0x08021000, fw_path)
 | 
											
												
													
														|  | 
 |  | +	d.read_flash_to_file(0x08000000, os.path.getsize(iap_path), iap_path_r)
 | 
											
												
													
														|  | 
 |  | +	d.read_flash_to_file(0x08021000, os.path.getsize(fw_path), fw_path_r)
 | 
											
												
													
														|  | 
 |  | +	os.system(f'diff {iap_path} {iap_path_r}')
 | 
											
												
													
														|  | 
 |  | +	os.system(f'diff {fw_path} {fw_path_r}')
 | 
											
												
													
														|  | 
 |  | +	os.remove(iap_path_r)
 | 
											
												
													
														|  | 
 |  | +	os.remove(fw_path_r)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	print(time.time() - start_time)
 |