import filecmp
import os
import sys
import time
from typing import Iterable, Union, Optional, Tuple

from serial import Serial


class ErrorAT32(IOError):
    """Common error for FlashAT32 util"""


# Anything FlashAT32.to_bytes() can turn into a byte string.
SendData = Union[str, bytes, bytearray, Iterable[int], int]


class FlashAT32:
    """Artery AT32x flashing utility talking to the UART bootloader.

    The instance owns an open serial port; call :meth:`close` when done or use
    the object as a context manager (``with FlashAT32(...) as d: ...``).
    """

    CMD_SYNC = 0x7F
    CMD_ACK = 0x79
    CMD_NACK = 0x1F
    ADDR_FLASH = 0x08000000
    # Bytes per read/write command issued by read_flash()/write_flash().
    # Must not exceed _MAX_TRANSFER.
    CHUNK_SIZE = 256
    # The transfer length is sent as one byte holding (count - 1), so a single
    # read/write command can move at most 256 bytes.
    _MAX_TRANSFER = 256
    # When True, every sent/received buffer is printed.
    DEBUG_PRINT = False

    def __init__(self, tty: str, baudrate: int = 115200):
        """Open serial port ``tty`` at ``baudrate`` with even parity, no XON/XOFF
        and a 1 s read timeout."""
        self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)

    def close(self):
        """Close the underlying serial port."""
        self.serial.close()

    def __enter__(self) -> 'FlashAT32':
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    @staticmethod
    def to_bytes(data: SendData) -> bytes:
        """Convert a str (encoded with the default codec), a single int byte,
        or any bytes-like / iterable of ints to ``bytes``."""
        if isinstance(data, str):
            return data.encode()
        if isinstance(data, int):
            return bytes((data, ))
        return bytes(data)

    @staticmethod
    def checksum(buf: bytes) -> int:
        """Return the XOR checksum byte for ``buf``.

        Multi-byte payloads are XOR-ed starting from 0. A single-byte payload
        is XOR-ed with 0xFF, yielding its bitwise complement. An empty payload
        yields 0xFF.
        """
        crc = 0 if len(buf) > 1 else 0xFF
        for b in buf:
            crc ^= b
        return crc

    def send(self, data: SendData, add_checksum: bool = False):
        """Write ``data`` to the device, optionally followed by its checksum byte."""
        buf = self.to_bytes(data)
        if add_checksum:
            buf += bytes((self.checksum(buf), ))
        if self.DEBUG_PRINT:
            print(f'> {buf}')
        self.serial.write(buf)

    def recv(self, count: int, timeout: float = 1) -> bytes:
        """Read up to ``count`` bytes, waiting at most ``timeout`` seconds.

        May return fewer bytes (possibly none) on timeout.
        """
        self.serial.timeout = timeout
        buf = self.serial.read(count)
        if self.DEBUG_PRINT:
            print(f'< {buf}')
        return buf

    def read(self, count: int, timeout: float = 3) -> bytes:
        """Read exactly ``count`` bytes, retrying until ``timeout`` seconds elapse.

        Raises:
            ErrorAT32: if fewer than ``count`` bytes arrived in time.
        """
        buf = bytearray()
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf.extend(self.recv(count - len(buf)))
            if len(buf) == count:
                return bytes(buf)
        raise ErrorAT32('Read timeout')

    def receive_ack(self, timeout: float = 10) -> Optional[bool]:
        """Wait for an ACK/NACK byte from the device.

        Bytes other than ACK/NACK are discarded.

        Returns:
            True on ACK, False on NACK, None on timeout.
        """
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf = self.recv(1)
            if buf == bytes((self.CMD_ACK, )):
                return True
            if buf == bytes((self.CMD_NACK, )):
                return False
        return None

    def wait_ack(self, timeout: float = 10):
        """Wait for ACK from the device.

        Raises:
            ErrorAT32: on NACK or timeout.
        """
        res = self.receive_ack(timeout)
        if res is None:
            raise ErrorAT32('ACK timeout')
        if not res:
            raise ErrorAT32('NACK received')

    def send_cmd(self, code: int):
        """Send command byte ``code`` followed by its complement, then wait for ACK."""
        self.send((code, ~code & 0xFF))
        self.wait_ack()

    def set_isp(self):
        """Send the host-info sequence (0xFA, 0x05) and wait for ACK.

        Raises:
            ErrorAT32: on NACK or timeout.
        """
        self.send((0xFA, 0x05))
        self.wait_ack()
        # NOTE(review): a follow-up stage was previously present but
        # unreachable. It sent (0x02, 0x03, 0x54, 0x41) plus checksum and
        # waited for ACK. Confirm against the AT32 bootloader protocol whether
        # it is needed before re-enabling.

    def connect(self, timeout: float = 10):
        """Synchronise with the AT32 bootloader, then run :meth:`set_isp`.

        Repeatedly sends CMD_SYNC, listening 10 ms for an ACK after each.

        Raises:
            ErrorAT32: if no ACK arrives within ``timeout`` seconds, or if
                :meth:`set_isp` fails.
        """
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            self.send(self.CMD_SYNC)
            if self.recv(1, 0.01) == bytes((self.CMD_ACK, )):
                return self.set_isp()
        raise ErrorAT32('Connection timeout')

    def access_unprotect(self):
        """Disable access protection (command 0x92).

        Expects two ACKs: one for the command, one after the operation.
        """
        self.send_cmd(0x92)
        self.wait_ack()

    def get_device_id(self) -> Tuple[int, int]:
        """Read device identification (command 0x02).

        Returns:
            ``(project_code, device_code)``.

        Raises:
            ErrorAT32: on an unexpected length byte or an unsupported project code.
        """
        self.send_cmd(0x02)
        length = self.read(1)[0]
        if length != 4:
            raise ErrorAT32('Incorrect response length')
        buf = self.read(length + 1)
        # NOTE(review): no trailing ACK is consumed here; if the bootloader
        # sends one, it stays in the receive buffer. Confirm.
        project_code = buf[4]
        device_code = buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
        if project_code not in {8}:
            raise ErrorAT32('Device not supported')
        return project_code, device_code

    def reset(self):
        """Software reboot of the device (command 0xD4, then a second ACK)."""
        self.send_cmd(0xD4)
        self.wait_ack()

    def read_mem(self, address: int, count: int) -> bytes:
        """Read ``count`` (1..256) bytes of device memory starting at ``address``.

        Raises:
            ValueError: if ``count`` is outside 1..256.
            ErrorAT32: on NACK or timeout.
        """
        # Validate first so no partial command is sent to the device.
        if not 1 <= count <= self._MAX_TRANSFER:
            raise ValueError(f'count must be in 1..{self._MAX_TRANSFER}, got {count}')
        self.send_cmd(0x11)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
        self.wait_ack()
        return self.read(count)

    def write_mem(self, address: int, buf: bytes):
        """Write ``buf`` (1..256 bytes) to device memory at ``address``.

        Raises:
            ValueError: if ``len(buf)`` is outside 1..256.
            ErrorAT32: on NACK or timeout.
        """
        # Validate first so no partial command is sent to the device.
        if not 1 <= len(buf) <= self._MAX_TRANSFER:
            raise ValueError(f'buffer length must be in 1..{self._MAX_TRANSFER}, got {len(buf)}')
        self.send_cmd(0x31)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send(bytes((len(buf) - 1, )) + buf, add_checksum=True)
        self.wait_ack()

    def get_flash_size(self) -> int:
        """Return the 16-bit little-endian flash-size value stored at 0x1FFFF7E0.

        Presumably in KiB; confirm against the device datasheet.
        """
        return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little')

    def get_uid(self) -> bytes:
        """Read the 12-byte unique chip ID at 0x1FFFF7E8."""
        return self.read_mem(0x1FFFF7E8, 12)

    def get_uid_str(self) -> str:
        """Return the unique chip ID as an upper-case hex string."""
        return self.get_uid().hex().upper()

    def read_flash(self, address: int, size: int) -> bytes:
        """Read ``size`` bytes of memory starting at ``address``.

        The read is split into CHUNK_SIZE-sized commands. A non-positive
        ``size`` returns empty bytes.
        """
        res = bytearray()
        for offset in range(0, size, self.CHUNK_SIZE):
            res.extend(self.read_mem(address + offset, min(self.CHUNK_SIZE, size - offset)))
        return bytes(res)

    def read_flash_to_file(self, address: int, size: int, path: str):
        """Read ``size`` bytes from ``address`` and write them to file ``path``."""
        data = self.read_flash(address, size)
        with open(path, 'wb') as f:
            f.write(data)

    def erase_flash(self):
        """Do a mass erase (command 0x44 with the 0xFFFF special code)."""
        self.send_cmd(0x44)
        self.send((0xFF, 0xFF), add_checksum=True)
        self.wait_ack()

    def write_flash(self, address: int, buf: bytes):
        """Write ``buf`` to memory at ``address``.

        The write is split into CHUNK_SIZE-sized commands.
        """
        # NOTE(review): the last chunk is sent unpadded. If the bootloader
        # requires word-aligned lengths, pad it (e.g. with 0xFF).
        for offset in range(0, len(buf), self.CHUNK_SIZE):
            self.write_mem(address + offset, buf[offset:offset + self.CHUNK_SIZE])

    def write_file_to_flash(self, address: int, path: str):
        """Write the contents of binary file ``path`` to memory at ``address``."""
        with open(path, 'rb') as f:
            self.write_flash(address, f.read())


def main() -> int:
    """Flash the IAP and firmware images, read them back and verify them.

    Returns:
        Process exit status: 0 on success, 1 on connection failure or
        read-back mismatch.
    """
    start_time = time.perf_counter()
    # (flash address, image to write, temporary read-back file)
    images = (
        (0x08000000, '../output/iap.bin', '../output/m3_artery_iap.bin'),
        (0x08021000, '../output/fw.bin', '../output/m3_artery_fw.bin'),
    )
    status = 0
    with FlashAT32('/dev/ttyUSB0', 115200 * 8) as d:
        d.DEBUG_PRINT = True
        try:
            d.connect()
        except OSError as e:  # ErrorAT32 and pyserial's SerialException are OSErrors
            # Without a bootloader session, erasing/writing cannot work: abort.
            print(e)
            return 1
        print(d.get_flash_size())
        print(d.get_uid_str())
        d.erase_flash()
        for address, path, _ in images:
            d.write_file_to_flash(address, path)
        for address, path, readback in images:
            d.read_flash_to_file(address, os.path.getsize(path), readback)
    for _, path, readback in images:
        if filecmp.cmp(path, readback, shallow=False):
            print(f'{path}: verified')
        else:
            print(f'{path}: MISMATCH after read-back')
            status = 1
        os.remove(readback)
    print(time.perf_counter() - start_time)
    return status


if __name__ == '__main__':
    sys.exit(main())