123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- import os
- import time
- from typing import Iterable, Union, Optional, Tuple
- from serial import Serial
class ErrorAT32(OSError):
    """Base error raised by the FlashAT32 utility (timeouts, NACKs, unsupported devices)."""
# Everything FlashAT32.to_bytes() accepts: text (encoded), raw byte buffers,
# an iterable of byte values, or a single byte value.
SendData = Union[
    str,
    bytes,
    bytearray,
    Iterable[int],
    int,
]
class FlashAT32:
    """Artery AT32x flashing utility for the chip's UART bootloader.

    The serial port is opened by the constructor. Release it with close(), or
    use the instance as a context manager::

        with FlashAT32('/dev/ttyUSB0', 115200) as dev:
            dev.connect()
    """

    CMD_SYNC = 0x7F  # sent repeatedly by connect() until CMD_ACK comes back
    CMD_ACK = 0x79
    CMD_NACK = 0x1F
    ADDR_FLASH = 0x08000000
    # Bytes per read_mem()/write_mem() call made by read_flash()/write_flash().
    # Must not exceed 256: the length goes on the wire as one byte holding (count - 1).
    CHUNK_SIZE = 256
    DEBUG_PRINT = False  # print every frame sent ('>') and received ('<')

    def __init__(self, tty: str, baudrate: int = 115200):
        """Open serial port *tty* at *baudrate*: even parity, no XON/XOFF, 1 s timeout."""
        self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)

    def close(self):
        """Close the serial port opened by the constructor."""
        self.serial.close()

    def __enter__(self) -> 'FlashAT32':
        return self

    def __exit__(self, *exc_info):
        self.close()

    @staticmethod
    def to_bytes(data: 'SendData') -> bytes:
        """Convert *data* to bytes.

        A ``str`` is encoded with ``str.encode()`` defaults, a single ``int`` becomes
        one byte (must be 0..255), and anything else goes through ``bytes()``.
        """
        if isinstance(data, str):
            return data.encode()
        if isinstance(data, int):
            return bytes((data,))
        return bytes(data)

    @staticmethod
    def checksum(buf: bytes) -> int:
        """Return the bootloader XOR checksum of *buf* (a plain XOR, not a CRC).

        For frames longer than one byte this is the XOR of all bytes. For a single
        byte it is that byte XOR 0xFF (its complement). An empty buffer gives 0xFF.
        """
        crc = 0 if len(buf) > 1 else 0xFF
        for b in buf:
            crc ^= b
        return crc

    def send(self, data: 'SendData', add_checksum: bool = False):
        """Write *data* to the port, optionally followed by its checksum() byte."""
        buf = self.to_bytes(data)
        if add_checksum:
            buf += bytes((self.checksum(buf),))
        if self.DEBUG_PRINT:
            print(f'> {buf}')
        self.serial.write(buf)

    def recv(self, count: int, timeout: float = 1) -> bytes:
        """Read up to *count* bytes, waiting at most *timeout* seconds; may return fewer.

        The port timeout is left set to *timeout* afterwards.
        """
        self.serial.timeout = timeout
        buf = self.serial.read(count)
        if self.DEBUG_PRINT:
            print(f'< {buf}')
        return buf

    def read(self, count: int, timeout: float = 3) -> bytes:
        """Read exactly *count* bytes within *timeout* seconds.

        Raises:
            ErrorAT32: if fewer than *count* bytes arrived in time.
        """
        buf = bytearray()
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf.extend(self.recv(count - len(buf)))
            if len(buf) == count:
                return bytes(buf)
        raise ErrorAT32('Read timeout')

    def receive_ack(self, timeout: float = 10) -> Optional[bool]:
        """Wait for an ACK/NACK byte; return True on ACK, False on NACK, None on timeout.

        Any other byte received meanwhile is discarded.
        """
        ack, nack = bytes((self.CMD_ACK,)), bytes((self.CMD_NACK,))
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf = self.recv(1)
            if buf == ack:
                return True
            if buf == nack:
                return False
        return None

    def wait_ack(self, timeout: float = 10):
        """Wait for ACK.

        Raises:
            ErrorAT32: on NACK or when nothing arrives within *timeout* seconds.
        """
        res = self.receive_ack(timeout)
        if res is None:
            raise ErrorAT32('ACK timeout')
        if not res:
            raise ErrorAT32('NACK received')

    def send_cmd(self, code: int):
        """Send command byte *code* followed by its complement, then wait for ACK."""
        self.send((code, ~code & 0xFF))
        self.wait_ack()

    def set_isp(self):
        """Send host data to the device.

        Sends 0xFA 0x05 and waits for ACK, then sends the fixed host-data frame
        02 03 54 41 (plus checksum) and waits for ACK again.

        Raises:
            ErrorAT32: on ACK timeout or NACK at either step.
        """
        self.send((0xFA, 0x05))
        # NOTE(review): a NACK here is treated as fatal. If some bootloader
        # revisions reject this command, that case needs explicit handling.
        self.wait_ack()
        self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
        self.wait_ack()

    def connect(self, timeout: float = 10):
        """Init connection with the AT32 bootloader.

        Sends CMD_SYNC roughly every 10 ms until CMD_ACK is received, then runs set_isp().

        Raises:
            ErrorAT32: if no ACK arrives within *timeout* seconds, or set_isp() fails.
        """
        ack = bytes((self.CMD_ACK,))
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            self.send(self.CMD_SYNC)
            if self.recv(1, 0.01) == ack:
                self.set_isp()
                return
        raise ErrorAT32('Connection timeout')

    def access_unprotect(self):
        """Disable access protection (command 0x92, then a second ACK is awaited)."""
        self.send_cmd(0x92)
        self.wait_ack()

    def get_device_id(self) -> Tuple[int, int]:
        """Read the device ID; return ``(project_code, device_code)``.

        Raises:
            ErrorAT32: if the reply length byte is not 4, or project_code is not 8.
        """
        self.send_cmd(0x02)
        length = self.read(1)[0]
        if length != 4:
            raise ErrorAT32('Incorrect response length')
        buf = self.read(length + 1)  # length byte is (number of bytes - 1)
        # The STM32-style Get ID reply (ST AN3155) ends with an ACK. Consume it so
        # the next command does not mistake it for its own ACK.
        # NOTE(review): assumes the AT32 bootloader uses the same framing - confirm.
        self.wait_ack()
        # buf[0:2] is the low 16-bit half and buf[2:4] the high half, each big-endian.
        device_code = buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
        project_code = buf[4]
        if project_code not in {8}:
            raise ErrorAT32('Device not supported')
        return project_code, device_code

    def reset(self):
        """SW reboot for device (command 0xD4, then a second ACK is awaited)."""
        self.send_cmd(0xD4)
        self.wait_ack()

    def read_mem(self, address: int, count: int) -> bytes:
        """Read *count* (1..256) bytes of device memory starting at *address*.

        Raises:
            ValueError: if *count* is outside 1..256 (checked before anything is sent).
            ErrorAT32: on NACK or timeout.
        """
        if not 1 <= count <= 256:
            raise ValueError(f'count must be in 1..256, got {count}')
        self.send_cmd(0x11)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send(bytes((count - 1,)), add_checksum=True)
        self.wait_ack()
        return self.read(count)

    def write_mem(self, address: int, buf: bytes):
        """Write *buf* (1..256 bytes) to device memory at *address*.

        Raises:
            ValueError: if len(buf) is outside 1..256 (checked before anything is sent).
            ErrorAT32: on NACK or timeout.
        """
        if not 1 <= len(buf) <= 256:
            raise ValueError(f'buffer length must be in 1..256, got {len(buf)}')
        self.send_cmd(0x31)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send(bytes((len(buf) - 1,)) + buf, add_checksum=True)
        self.wait_ack()

    def get_flash_size(self) -> int:
        """Return the 16-bit little-endian flash-size value stored at 0x1FFFF7E0.

        NOTE(review): presumably in KiB - confirm against the datasheet.
        """
        return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little')

    def get_uid(self) -> bytes:
        """Read the 12-byte unique chip ID at 0x1FFFF7E8."""
        return self.read_mem(0x1FFFF7E8, 12)

    def get_uid_str(self) -> str:
        """Return the unique chip ID as an uppercase hex string."""
        return ''.join(f'{b:02X}' for b in self.get_uid())

    def read_flash(self, address: int, size: int) -> bytes:
        """Read *size* bytes starting at *address*, in CHUNK_SIZE pieces.

        A size of zero or less yields ``b''``.
        """
        res = bytearray()
        for offset in range(0, size, self.CHUNK_SIZE):
            res.extend(self.read_mem(address + offset, min(self.CHUNK_SIZE, size - offset)))
        return bytes(res)

    def read_flash_to_file(self, address: int, size: int, path: str):
        """Read *size* bytes starting at *address* and write them to file *path*."""
        data = self.read_flash(address, size)
        with open(path, 'wb') as f:
            f.write(data)

    def erase_flash(self):
        """Do a mass erase (command 0x44 with erase code 0xFFFF)."""
        self.send_cmd(0x44)
        self.send((0xFF, 0xFF), add_checksum=True)
        self.wait_ack()

    def write_flash(self, address: int, buf: bytes):
        """Write *buf* to flash at *address*, in CHUNK_SIZE pieces (last may be shorter)."""
        for offset in range(0, len(buf), self.CHUNK_SIZE):
            self.write_mem(address + offset, buf[offset:offset + self.CHUNK_SIZE])

    def write_file_to_flash(self, address: int, path: str):
        """Write the whole content of binary file *path* to flash at *address*."""
        with open(path, 'rb') as f:
            data = f.read()
        self.write_flash(address, data)
if __name__ == '__main__':
    # Flash the IAP and firmware images, then read them back to verify.
    start_time = time.monotonic()  # monotonic: elapsed time unaffected by clock changes
    d = FlashAT32('/dev/ttyUSB0', 115200 * 8)
    d.DEBUG_PRINT = True
    # (flash address, image file) pairs, written and verified in this order.
    images = (
        (0x08000000, '../output/iap.bin'),
        (0x08021000, '../output/fw.bin'),
    )
    failed = []
    try:
        try:
            d.connect()
        except ErrorAT32 as e:
            # Every step below needs a bootloader session; stop here instead of
            # failing later with a less obvious error.
            raise SystemExit(f'Connection failed: {e}') from e
        print(d.get_flash_size())
        print(d.get_uid_str())
        d.erase_flash()
        for address, path in images:
            d.write_file_to_flash(address, path)
        # Verify by reading each image back and comparing in memory
        # (no temporary files, no shell `diff`).
        for address, path in images:
            with open(path, 'rb') as f:
                expected = f.read()
            if d.read_flash(address, len(expected)) != expected:
                failed.append(path)
                print(f'Verify FAILED: {path} at 0x{address:08X}')
    finally:
        d.serial.close()
    print(time.monotonic() - start_time)
    if failed:
        raise SystemExit(1)
|