|
@@ -0,0 +1,233 @@
|
|
|
+import os
|
|
|
+import time
|
|
|
+from typing import Iterable, Union, Optional, Tuple
|
|
|
+from serial import Serial
|
|
|
+
|
|
|
+
|
|
|
+class ErrorAT32(IOError):
|
|
|
+ """Common error for FlashAT32 util"""
|
|
|
+
|
|
|
+
|
|
|
+SendData = Union[str, bytes, bytearray, Iterable[int], int]
|
|
|
+
|
|
|
+
|
|
|
+class FlashAT32:
|
|
|
+ """Artery AT32x flashing utility"""
|
|
|
+ CMD_SYNC = 0x7F
|
|
|
+ CMD_ACK = 0x79
|
|
|
+ CMD_NACK = 0x1F
|
|
|
+ ADDR_FLASH = 0x08000000
|
|
|
+ CHUNK_SIZE = 256
|
|
|
+ DEBUG_PRINT = False
|
|
|
+
|
|
|
+ # def __init__(self, tty: str, baudrate: int = 115200):
|
|
|
+ def __init__(self, tty: str, baudrate: int):
|
|
|
+ self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def to_bytes(data: SendData) -> bytes:
|
|
|
+ """Convert various types of data to bytes"""
|
|
|
+ if isinstance(data, str):
|
|
|
+ return data.encode()
|
|
|
+ elif isinstance(data, int):
|
|
|
+ return bytes((data, ))
|
|
|
+ else:
|
|
|
+ return bytes(data)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def checksum(buf: bytes) -> int:
|
|
|
+ """Calculate standard CRC-8 checksum"""
|
|
|
+ crc = 0 if len(buf) > 1 else 0xFF
|
|
|
+ for b in buf:
|
|
|
+ crc = crc.__xor__(b)
|
|
|
+ return crc
|
|
|
+
|
|
|
+ def send(self, data: SendData, add_checksum: bool = False):
|
|
|
+ """Send data to device"""
|
|
|
+ buf = self.to_bytes(data)
|
|
|
+ if add_checksum:
|
|
|
+ buf += bytes((self.checksum(buf), ))
|
|
|
+ if self.DEBUG_PRINT:
|
|
|
+ print(f'> {buf}')
|
|
|
+ self.serial.write(buf)
|
|
|
+
|
|
|
+ def recv(self, count: int, timeout: float = 1) -> bytes:
|
|
|
+ """Receive bytes from device"""
|
|
|
+ self.serial.timeout = timeout
|
|
|
+ buf = self.serial.read(count)
|
|
|
+ if self.DEBUG_PRINT:
|
|
|
+ print(f'< {buf}')
|
|
|
+ return buf
|
|
|
+
|
|
|
+ def read(self, count: int, timeout: float = 3) -> bytes:
|
|
|
+ """Read count bytes from device or raise exception"""
|
|
|
+ buf = bytearray()
|
|
|
+ start_time = time.monotonic()
|
|
|
+ while time.monotonic() - start_time < timeout:
|
|
|
+ buf.extend(self.recv(count - len(buf)))
|
|
|
+ if len(buf) == count:
|
|
|
+ return bytes(buf)
|
|
|
+ raise ErrorAT32('Read timeout')
|
|
|
+
|
|
|
+ def receive_ack(self, timeout: float = 10) -> Optional[bool]:
|
|
|
+ """Wait ACK byte from device, return True on ACN, False on NACK, None on timeout"""
|
|
|
+ start_time = time.monotonic()
|
|
|
+ while time.monotonic() - start_time < timeout:
|
|
|
+ buf = self.recv(1)
|
|
|
+ if buf == bytes((self.CMD_ACK, )):
|
|
|
+ return True
|
|
|
+ if buf == bytes((self.CMD_NACK, )):
|
|
|
+ return False
|
|
|
+ return None
|
|
|
+
|
|
|
+ def wait_ack(self, timeout: float = 10):
|
|
|
+ """Wait ACK byte from device, raise exception on NACK or timeout"""
|
|
|
+ res = self.receive_ack(timeout)
|
|
|
+ if res is None:
|
|
|
+ raise ErrorAT32('ACK timeout')
|
|
|
+ if not res:
|
|
|
+ raise ErrorAT32('NACK received')
|
|
|
+
|
|
|
+ def send_cmd(self, code: int):
|
|
|
+ """Send command and wait ACK"""
|
|
|
+ self.send((code, ~code & 0xFF))
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def set_isp(self):
|
|
|
+ """Send host data to device"""
|
|
|
+ self.send((0xFA, 0x05))
|
|
|
+ ack = self.receive_ack()
|
|
|
+ if ack is None:
|
|
|
+ raise ErrorAT32('ACK timeout')
|
|
|
+ if not ack:
|
|
|
+ raise ErrorAT32('NACK received')
|
|
|
+ return
|
|
|
+ self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def connect(self, timeout: float = 10):
|
|
|
+ """Init connection with AT32 bootloader"""
|
|
|
+ start_time = time.monotonic()
|
|
|
+ while time.monotonic() - start_time < timeout:
|
|
|
+ self.send(self.CMD_SYNC)
|
|
|
+ if self.recv(1, 0.01) == bytes((self.CMD_ACK, )):
|
|
|
+ return self.set_isp()
|
|
|
+ raise ErrorAT32('Connection timeout')
|
|
|
+
|
|
|
+ def access_unprotect(self):
|
|
|
+ """Disable access protection"""
|
|
|
+ self.send_cmd(0x92)
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def get_device_id(self) -> Tuple[int, int]:
|
|
|
+ """Read some device info"""
|
|
|
+ self.send_cmd(0x02)
|
|
|
+ length = self.read(1)[0]
|
|
|
+ if length != 4:
|
|
|
+ raise ErrorAT32('Incorrect response length')
|
|
|
+ buf = self.read(length + 1)
|
|
|
+ project_code, device_code = buf[4], buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
|
|
|
+ if project_code not in {8, }:
|
|
|
+ raise ErrorAT32('Device not supported')
|
|
|
+ return project_code, device_code
|
|
|
+
|
|
|
+ def reset(self):
|
|
|
+ """SW reboot for device"""
|
|
|
+ self.send_cmd(0xD4)
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def read_mem(self, address: int, count: int) -> bytes:
|
|
|
+ """Read data from device memory"""
|
|
|
+ self.send_cmd(0x11)
|
|
|
+ self.send(address.to_bytes(4, 'big'), add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+ self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+ return self.read(count)
|
|
|
+
|
|
|
+ def write_mem(self, address: int, buf: bytes):
|
|
|
+ """Write data to device memory"""
|
|
|
+ self.send_cmd(0x31)
|
|
|
+ self.send(address.to_bytes(4, 'big'), add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+ self.send(bytes((len(buf) - 1, )) + buf, add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def get_flash_size(self) -> int:
|
|
|
+ """Read flash size"""
|
|
|
+ return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little') & 0x0000FFFF
|
|
|
+
|
|
|
+ def get_uid(self) -> bytes:
|
|
|
+ """Read unique chip ID"""
|
|
|
+ return self.read_mem(0x1FFFF7E8, 12)
|
|
|
+ # buf = bytearray()
|
|
|
+ # x = 3
|
|
|
+ # for i in range(12 // x):
|
|
|
+ # buf.extend(self.read_mem(0x1FFFF7E8 + i * x, 1 * x))
|
|
|
+ # return bytes(buf)
|
|
|
+
|
|
|
+ def get_uid_str(self) -> str:
|
|
|
+ """Get unique chip ID as hex string"""
|
|
|
+ return ''.join(f'{b:02X}' for b in self.get_uid())
|
|
|
+
|
|
|
+ def read_flash(self, address: int, size: int) -> bytes:
|
|
|
+ """Read all content of flash memory"""
|
|
|
+ chunks, last_bytes_count = divmod(size, self.CHUNK_SIZE)
|
|
|
+ res = bytearray()
|
|
|
+ for i in range(chunks):
|
|
|
+ res.extend(self.read_mem(address + i * self.CHUNK_SIZE, self.CHUNK_SIZE))
|
|
|
+ if last_bytes_count != 0:
|
|
|
+ res.extend(self.read_mem(address + chunks * self.CHUNK_SIZE, last_bytes_count))
|
|
|
+ return bytes(res)
|
|
|
+
|
|
|
+ def read_flash_to_file(self, address: int, size: int, path: str):
|
|
|
+ """Read all content of flash memory to file"""
|
|
|
+ with open(path, 'wb') as f:
|
|
|
+ f.write(self.read_flash(address, size))
|
|
|
+
|
|
|
+ def erase_flash(self):
|
|
|
+ """Do mass erase"""
|
|
|
+ self.send_cmd(0x44)
|
|
|
+ self.send((0xFF, 0xFF), add_checksum=True)
|
|
|
+ self.wait_ack()
|
|
|
+
|
|
|
+ def write_flash(self, address: int, buf: bytes):
|
|
|
+ """Write binary data to flash"""
|
|
|
+ flash_full_chunks, last_chunk_size = divmod(len(buf), self.CHUNK_SIZE)
|
|
|
+ for i in range(flash_full_chunks):
|
|
|
+ self.write_mem(address + i * self.CHUNK_SIZE, buf[i * self.CHUNK_SIZE:(i + 1) * self.CHUNK_SIZE])
|
|
|
+ if last_chunk_size > 0:
|
|
|
+ self.write_mem(address + flash_full_chunks * self.CHUNK_SIZE,
|
|
|
+ buf[flash_full_chunks * self.CHUNK_SIZE:])
|
|
|
+
|
|
|
+ def write_file_to_flash(self, address: int, path: str):
|
|
|
+ """write binary file to flash"""
|
|
|
+ with open(path, 'rb') as f:
|
|
|
+ self.write_flash(address, f.read())
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ start_time = time.time()
|
|
|
+ d = FlashAT32('/dev/ttyUSB0', 115200 * 8)
|
|
|
+ d.DEBUG_PRINT = True
|
|
|
+ try:
|
|
|
+ d.connect()
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+ print(d.get_flash_size())
|
|
|
+ print(d.get_uid_str())
|
|
|
+ d.erase_flash()
|
|
|
+ iap_path = '../output/iap.bin'
|
|
|
+ fw_path = '../output/fw.bin'
|
|
|
+ iap_path_r = '../output/m3_artery_iap.bin'
|
|
|
+ fw_path_r = '../output/m3_artery_fw.bin'
|
|
|
+ d.write_file_to_flash(0x08000000, iap_path)
|
|
|
+ d.write_file_to_flash(0x08021000, fw_path)
|
|
|
+ d.read_flash_to_file(0x08000000, os.path.getsize(iap_path), iap_path_r)
|
|
|
+ d.read_flash_to_file(0x08021000, os.path.getsize(fw_path), fw_path_r)
|
|
|
+ os.system(f'diff {iap_path} {iap_path_r}')
|
|
|
+ os.system(f'diff {fw_path} {fw_path_r}')
|
|
|
+ os.remove(iap_path_r)
|
|
|
+ os.remove(fw_path_r)
|
|
|
+
|
|
|
+ print(time.time() - start_time)
|