import os
import re
import shutil
import subprocess
import time
from typing import Iterable, Union, Optional, Tuple

import inquirer
from serial import Serial
from tqdm import tqdm


class ErrorAT32(IOError):
    """Common error for FlashAT32 util"""


# Anything FlashAT32.to_bytes() can turn into a byte string.
SendData = Union[str, bytes, bytearray, Iterable[int], int]


class FlashAT32:
    """Artery AT32x flashing utility talking to the UART bootloader.

    The serial port is opened in ``__init__``. Release it with ``close()``,
    or use the instance as a context manager::

        with FlashAT32('COM53', 115200) as d:
            d.connect()
    """

    CMD_SYNC = 0x7F  # sent repeatedly by connect() until the bootloader ACKs
    CMD_ACK = 0x79
    CMD_NACK = 0x1F

    ADDR_FLASH = 0x08000000
    CHUNK_SIZE = 256  # bytes per read_mem/write_mem call in read_flash/write_flash

    # The transfer length goes on the wire as a single byte holding (count - 1),
    # so one read_mem/write_mem call can move 1..256 bytes.
    MAX_TRANSFER_SIZE = 256

    # Project codes accepted by get_device_id().
    SUPPORTED_PROJECT_CODES = frozenset({8})

    DEBUG_PRINT = False

    def __init__(self, tty: str, baudrate: int):
        """Open *tty* at *baudrate* with even parity and no software flow control."""
        self.serial = Serial(port=tty, baudrate=baudrate, timeout=1, parity='E', xonxoff=False)

    def close(self):
        """Close the underlying serial port."""
        self.serial.close()

    def __enter__(self) -> 'FlashAT32':
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @staticmethod
    def to_bytes(data: SendData) -> bytes:
        """Convert various types of data to bytes.

        ``str`` is encoded with the default codec, and an ``int`` becomes a
        single byte. Anything else is passed to ``bytes()``.
        """
        if isinstance(data, str):
            return data.encode()
        elif isinstance(data, int):
            return bytes((data, ))
        else:
            return bytes(data)

    @staticmethod
    def checksum(buf: bytes) -> int:
        """Return the XOR checksum of *buf* (this is not a CRC).

        For more than one byte this is the XOR of all bytes. For a single byte
        the value is XORed with 0xFF, i.e. its bitwise complement. An empty
        buffer yields 0xFF.
        """
        crc = 0 if len(buf) > 1 else 0xFF
        for b in buf:
            crc ^= b
        return crc

    def send(self, data: SendData, add_checksum: bool = False):
        """Send *data* to the device, optionally followed by its checksum byte."""
        buf = self.to_bytes(data)
        if add_checksum:
            buf += bytes((self.checksum(buf), ))
        if self.DEBUG_PRINT:
            print(f'> {buf}')
        self.serial.write(buf)

    def recv(self, count: int, timeout: float = 1) -> bytes:
        """Receive up to *count* bytes, waiting at most *timeout* seconds.

        The result may be shorter than *count* if the timeout expires.
        """
        self.serial.timeout = timeout
        buf = self.serial.read(count)
        if self.DEBUG_PRINT:
            print(f'< {buf}')
        return buf

    def read(self, count: int, timeout: float = 3) -> bytes:
        """Read exactly *count* bytes, retrying until *timeout* seconds elapse.

        Raises:
            ErrorAT32: if fewer than *count* bytes arrived in time.
        """
        buf = bytearray()
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf.extend(self.recv(count - len(buf)))
            if len(buf) == count:
                return bytes(buf)
        raise ErrorAT32('Read timeout')

    def receive_ack(self, timeout: float = 10) -> Optional[bool]:
        """Wait for an ACK/NACK byte from the device.

        Bytes other than ACK/NACK are skipped.

        Returns:
            True on ACK, False on NACK, None on timeout.
        """
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            buf = self.recv(1)
            if buf == bytes((self.CMD_ACK, )):
                return True
            if buf == bytes((self.CMD_NACK, )):
                return False
        return None

    def wait_ack(self, timeout: float = 10):
        """Wait for an ACK byte from the device.

        Raises:
            ErrorAT32: on NACK or timeout.
        """
        res = self.receive_ack(timeout)
        if res is None:
            raise ErrorAT32('ACK timeout')
        if not res:
            raise ErrorAT32('NACK received')

    def send_cmd(self, code: int):
        """Send a command byte followed by its complement, then wait for ACK."""
        self.send((code, ~code & 0xFF))
        self.wait_ack()

    def set_isp(self):
        """Send the host handshake to the device.

        First sends 0xFA 0x05, then a checksummed 4-byte payload. Each step
        must be ACKed.

        Raises:
            ErrorAT32: on ACK timeout or NACK at either step.
        """
        self.send((0xFA, 0x05))
        self.wait_ack()
        self.send((0x02, 0x03, 0x54, 0x41), add_checksum=True)
        self.wait_ack()

    def connect(self, timeout: float = 10):
        """Init connection with the AT32 bootloader.

        Repeatedly sends CMD_SYNC, polling 10 ms for an ACK each time. Once an
        ACK arrives it performs the set_isp() handshake.

        Raises:
            ErrorAT32: if no ACK arrives within *timeout* seconds, or if the
                handshake fails.
        """
        start_time = time.monotonic()
        while time.monotonic() - start_time < timeout:
            self.send(self.CMD_SYNC)
            if self.recv(1, 0.01) == bytes((self.CMD_ACK, )):
                return self.set_isp()
        raise ErrorAT32('Connection timeout')

    def access_unprotect(self):
        """Disable access protection (command 0x92).

        Two ACKs are expected: one for the command, one after the operation.
        """
        self.send_cmd(0x92)
        self.wait_ack()

    def get_device_id(self) -> Tuple[int, int]:
        """Read the device identification (command 0x02).

        Returns:
            ``(project_code, device_code)``.

        Raises:
            ErrorAT32: if the reported length is not 4, or the project code is
                not in SUPPORTED_PROJECT_CODES.
        """
        self.send_cmd(0x02)
        length = self.read(1)[0]
        if length != 4:
            raise ErrorAT32('Incorrect response length')
        # The length byte is (number of bytes - 1), so 5 bytes follow.
        # NOTE(review): if the bootloader appends a trailing ACK after these
        # bytes, it is not consumed here. Confirm against the Artery ISP spec.
        buf = self.read(length + 1)
        project_code = buf[4]
        # Device code byte order on the wire: [2][3][0][1] from MSB to LSB.
        device_code = buf[1] | (buf[0] << 8) | (buf[3] << 16) | (buf[2] << 24)
        if project_code not in self.SUPPORTED_PROJECT_CODES:
            raise ErrorAT32('Device not supported')
        return project_code, device_code

    def reset(self):
        """SW reboot for device (command 0xD4; two ACKs are expected)."""
        self.send_cmd(0xD4)
        self.wait_ack()

    def _check_transfer_size(self, count: int):
        """Reject lengths that cannot be encoded in the one-byte length field."""
        if not 1 <= count <= self.MAX_TRANSFER_SIZE:
            raise ValueError(f'Transfer size must be 1..{self.MAX_TRANSFER_SIZE}, got {count}')

    def read_mem(self, address: int, count: int) -> bytes:
        """Read *count* bytes (1..256) of device memory starting at *address*.

        Raises:
            ValueError: if *count* is out of range. The check happens before
                anything is sent, so the bootloader is never left mid-command.
            ErrorAT32: on NACK or timeout.
        """
        self._check_transfer_size(count)
        self.send_cmd(0x11)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send((count - 1).to_bytes(1, 'big'), add_checksum=True)
        self.wait_ack()
        return self.read(count)

    def write_mem(self, address: int, buf: bytes):
        """Write *buf* (1..256 bytes) to device memory at *address*.

        Raises:
            ValueError: if *buf* length is out of range. The check happens
                before anything is sent.
            ErrorAT32: on NACK or timeout.
        """
        self._check_transfer_size(len(buf))
        self.send_cmd(0x31)
        self.send(address.to_bytes(4, 'big'), add_checksum=True)
        self.wait_ack()
        self.send(bytes((len(buf) - 1, )) + buf, add_checksum=True)
        self.wait_ack()

    def get_flash_size(self) -> int:
        """Read the 16-bit little-endian flash size value stored at 0x1FFFF7E0.

        The value is presumably in KiB; confirm against the AT32 datasheet.
        """
        return int.from_bytes(self.read_mem(0x1FFFF7E0, 2), 'little')

    def get_uid(self) -> bytes:
        """Read the 12-byte unique chip ID stored at 0x1FFFF7E8."""
        return self.read_mem(0x1FFFF7E8, 12)

    def get_uid_str(self) -> str:
        """Get the unique chip ID as an uppercase hex string."""
        return ''.join(f'{b:02X}' for b in self.get_uid())

    def read_flash(self, address: int, size: int) -> bytes:
        """Read *size* bytes of memory starting at *address*, in CHUNK_SIZE pieces."""
        res = bytearray()
        for offset in range(0, size, self.CHUNK_SIZE):
            res.extend(self.read_mem(address + offset, min(self.CHUNK_SIZE, size - offset)))
        return bytes(res)

    def read_flash_to_file(self, address: int, size: int, path: str):
        """Read *size* bytes starting at *address* and save them to *path*.

        The read completes before the file is opened, so a failed read does
        not truncate an existing file.
        """
        data = self.read_flash(address, size)
        with open(path, 'wb') as f:
            f.write(data)

    def erase_flash(self):
        """Do mass erase (command 0x44 with the special code 0xFFFF)."""
        self.send_cmd(0x44)
        self.send((0xFF, 0xFF), add_checksum=True)
        self.wait_ack()

    def write_flash(self, address: int, buf: bytes):
        """Write *buf* to flash at *address* in CHUNK_SIZE pieces, with a progress bar."""
        for offset in tqdm(range(0, len(buf), self.CHUNK_SIZE)):
            self.write_mem(address + offset, buf[offset:offset + self.CHUNK_SIZE])

    def write_file_to_flash(self, address: int, path: str):
        """Write the contents of binary file *path* to flash at *address*."""
        with open(path, 'rb') as f:
            self.write_flash(address, f.read())


def menu_test():
    """Interactive inquirer demo: ask for a size and print the answers."""
    questions = [
        inquirer.List(
            'size',
            message="What size do you need?",
            choices=['Jumbo', 'Large', 'Standard', 'Medium', 'Small', 'Micro'],
        ),
    ]
    answers = inquirer.prompt(questions)
    print(answers)


if __name__ == '__main__':
    # Copy the binaries and append the CRC to fw.bin
    # print("Копирование bin файлов")
    # shutil.copy('../../output/iap.bin', 'bin/iap.bin')
    # shutil.copy('../../output/fw.bin', 'bin/fw.bin')
    # shutil.copy('../../output/crc_ewarm.exe', 'bin/crc_ewarm.exe')
    # print('Добавление CRC в fw.bin')
    # os.startfile('bin\crc_ewarm.exe')

    # NOTE: the example session below is disabled. It is kept as a string
    # literal, so it is never executed.
    """
    start_time = time.time()
    d = FlashAT32('COM53', 115200 * 8)
    d.DEBUG_PRINT = True
    try:
        d.connect()
    except Exception as e:
        print(e)
    print(d.get_flash_size())
    print(d.get_uid_str())
    d.erase_flash()
    iap_path = 'bin/iap.bin'
    fw_path = 'bin/fw.bin'
    iap_path_r = 'bin/m3_artery_iap.bin'
    fw_path_r = 'bin/m3_artery_fw.bin'
    d.write_file_to_flash(0x08000000, iap_path)
    d.write_file_to_flash(0x08021000, fw_path)
    d.read_flash_to_file(0x08000000, os.path.getsize(iap_path), iap_path_r)
    d.read_flash_to_file(0x08021000, os.path.getsize(fw_path), fw_path_r)
    # os.system(f'diff {iap_path} {iap_path_r}')
    # os.system(f'diff {fw_path} {fw_path_r}')
    os.remove(iap_path_r)
    os.remove(fw_path_r)
    print(time.time() - start_time)
    """