artery_loader.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import os
  2. import time
  3. from typing import Iterable, Union, Optional, Tuple
  4. from serial import Serial
  5. class ErrorAT32(IOError):
  6. """Common error for FlashAT32 util"""
  7. SendData = Union[str, bytes, bytearray, Iterable[int], int]
  8. class FlashAT32:
  9. """Artery AT32x flashing utility"""
  10. CMD_SYNC = 0x7F
  11. CMD_ACK = 0x79
  12. CMD_NACK = 0x1F
  13. ADDR_FLASH = 0x08000000
  14. CHUNK_SIZE = 256
  15. DEBUG_PRINT = False
  16. # def __init__(self, tty: str, baudrate: int = 115200):
  17. def __init__(self, tty: str, baudrate: int):
  18. self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)
  19. @staticmethod
  20. def to_bytes(data: SendData) -> bytes:
  21. """Convert various types of data to bytes"""
  22. if isinstance(data, str):
  23. return data.encode()
  24. elif isinstance(data, int):
  25. return bytes((data, ))
  26. else:
  27. return bytes(data)
  28. @staticmethod
  29. def checksum(buf: bytes) -> int:
  30. """Calculate standard CRC-8 checksum"""
  31. crc = 0 if len(buf) > 1 else 0xFF
  32. for b in buf:
  33. crc = crc.__xor__(b)
  34. return crc
  35. def send(self, data: SendData, add_checksum: bool = False):
  36. """Send data to device"""
  37. buf = self.to_bytes(data)
  38. if add_checksum:
  39. buf += bytes((self.checksum(buf), ))
  40. if self.DEBUG_PRINT:
  41. print(f'> {buf}')
  42. self.serial.write(buf)
  43. def recv(self, count: int, timeout: float = 1) -> bytes:
  44. """Receive bytes from device"""
  45. self.serial.timeout = timeout
  46. buf = self.serial.read(count)
  47. if self.DEBUG_PRINT:
  48. print(f'< {buf}')
  49. return buf
  50. def read(self, count: int, timeout: float = 3) -> bytes:
  51. """Read count bytes from device or raise exception"""
  52. buf = bytearray()
  53. start_time = time.monotonic()
  54. while time.monotonic() - start_time < timeout:
  55. buf.extend(self.recv(count - len(buf)))
  56. if len(buf) == count:
  57. return bytes(buf)
  58. raise ErrorAT32('Read timeout')
  59. def receive_ack(self, timeout: float = 10) -> Optional[bool]:
  60. """Wait ACK byte from device, return True on ACN, False on NACK, None on timeout"""
  61. start_time = time.monotonic()
  62. while time.monotonic() - start_time < timeout:
  63. buf = self.recv(1)
  64. if buf == bytes((self.CMD_ACK, )):
  65. return True
  66. if buf == bytes((self.CMD_NACK, )):
  67. return False
  68. return None
  69. def wait_ack(self, timeout: float = 10):
  70. """Wait ACK byte from device, raise exception on NACK or timeout"""
  71. res = self.receive_ack(timeout)
  72. if res is None:
  73. raise ErrorAT32('ACK timeout')
  74. if not res:
  75. raise ErrorAT32('NACK received')
  76. def send_cmd(self, code: int):
  77. """Send command and wait ACK"""
  78. self.send((code, ~code & 0xFF))
  79. self.wait_ack()
  80. def set_isp(self):
  81. """Send host data to device"""
  82. self.send((0xFA, 0x05))
  83. ack = self.receive_ack()
  84. if ack is None:
  85. raise ErrorAT32('ACK timeout')
  86. if not ack:
  87. raise ErrorAT32('NACK received')
  88. return
  89. self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
  90. self.wait_ack()
  91. def connect(self, timeout: float = 10):
  92. """Init connection with AT32 bootloader"""
  93. start_time = time.monotonic()
  94. while time.monotonic() - start_time < timeout:
  95. self.send(self.CMD_SYNC)
  96. if self.recv(1, 0.01) == bytes((self.CMD_ACK, )):
  97. return self.set_isp()
  98. raise ErrorAT32('Connection timeout')
  99. def access_unprotect(self):
  100. """Disable access protection"""
  101. self.send_cmd(0x92)
  102. self.wait_ack()
  103. def get_device_id(self) -> Tuple[int, int]:
  104. """Read some device info"""
  105. self.send_cmd(0x02)
  106. length = self.read(1)[0]
  107. if length != 4:
  108. raise ErrorAT32('Incorrect response length')
  109. buf = self.read(length + 1)
  110. project_code, device_code = buf[4], buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
  111. if project_code not in {8, }:
  112. raise ErrorAT32('Device not supported')
  113. return project_code, device_code
  114. def reset(self):
  115. """SW reboot for device"""
  116. self.send_cmd(0xD4)
  117. self.wait_ack()
  118. def read_mem(self, address: int, count: int) -> bytes:
  119. """Read data from device memory"""
  120. self.send_cmd(0x11)
  121. self.send(address.to_bytes(4, 'big'), add_checksum=True)
  122. self.wait_ack()
  123. self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
  124. self.wait_ack()
  125. return self.read(count)
  126. def write_mem(self, address: int, buf: bytes):
  127. """Write data to device memory"""
  128. self.send_cmd(0x31)
  129. self.send(address.to_bytes(4, 'big'), add_checksum=True)
  130. self.wait_ack()
  131. self.send(bytes((len(buf) - 1, )) + buf, add_checksum=True)
  132. self.wait_ack()
  133. def get_flash_size(self) -> int:
  134. """Read flash size"""
  135. return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little') & 0x0000FFFF
  136. def get_uid(self) -> bytes:
  137. """Read unique chip ID"""
  138. return self.read_mem(0x1FFFF7E8, 12)
  139. # buf = bytearray()
  140. # x = 3
  141. # for i in range(12 // x):
  142. # buf.extend(self.read_mem(0x1FFFF7E8 + i * x, 1 * x))
  143. # return bytes(buf)
  144. def get_uid_str(self) -> str:
  145. """Get unique chip ID as hex string"""
  146. return ''.join(f'{b:02X}' for b in self.get_uid())
  147. def read_flash(self, address: int, size: int) -> bytes:
  148. """Read all content of flash memory"""
  149. chunks, last_bytes_count = divmod(size, self.CHUNK_SIZE)
  150. res = bytearray()
  151. for i in range(chunks):
  152. res.extend(self.read_mem(address + i * self.CHUNK_SIZE, self.CHUNK_SIZE))
  153. if last_bytes_count != 0:
  154. res.extend(self.read_mem(address + chunks * self.CHUNK_SIZE, last_bytes_count))
  155. return bytes(res)
  156. def read_flash_to_file(self, address: int, size: int, path: str):
  157. """Read all content of flash memory to file"""
  158. with open(path, 'wb') as f:
  159. f.write(self.read_flash(address, size))
  160. def erase_flash(self):
  161. """Do mass erase"""
  162. self.send_cmd(0x44)
  163. self.send((0xFF, 0xFF), add_checksum=True)
  164. self.wait_ack()
  165. def write_flash(self, address: int, buf: bytes):
  166. """Write binary data to flash"""
  167. flash_full_chunks, last_chunk_size = divmod(len(buf), self.CHUNK_SIZE)
  168. for i in range(flash_full_chunks):
  169. self.write_mem(address + i * self.CHUNK_SIZE, buf[i * self.CHUNK_SIZE:(i + 1) * self.CHUNK_SIZE])
  170. if last_chunk_size > 0:
  171. self.write_mem(address + flash_full_chunks * self.CHUNK_SIZE,
  172. buf[flash_full_chunks * self.CHUNK_SIZE:])
  173. def write_file_to_flash(self, address: int, path: str):
  174. """write binary file to flash"""
  175. with open(path, 'rb') as f:
  176. self.write_flash(address, f.read())
  177. if __name__ == '__main__':
  178. start_time = time.time()
  179. d = FlashAT32('/dev/ttyUSB0', 115200 * 8)
  180. d.DEBUG_PRINT = True
  181. try:
  182. d.connect()
  183. except Exception as e:
  184. print(e)
  185. print(d.get_flash_size())
  186. print(d.get_uid_str())
  187. d.erase_flash()
  188. iap_path = '../output/iap.bin'
  189. fw_path = '../output/fw.bin'
  190. iap_path_r = '../output/m3_artery_iap.bin'
  191. fw_path_r = '../output/m3_artery_fw.bin'
  192. d.write_file_to_flash(0x08000000, iap_path)
  193. d.write_file_to_flash(0x08021000, fw_path)
  194. d.read_flash_to_file(0x08000000, os.path.getsize(iap_path), iap_path_r)
  195. d.read_flash_to_file(0x08021000, os.path.getsize(fw_path), fw_path_r)
  196. os.system(f'diff {iap_path} {iap_path_r}')
  197. os.system(f'diff {fw_path} {fw_path_r}')
  198. os.remove(iap_path_r)
  199. os.remove(fw_path_r)
  200. print(time.time() - start_time)