"""
Структура TCP сегмента:

Слово 1

    SRC (16 бит) номер порта источника
    DST (16 бит) номер порта назначения

Слово 2    
    SEQ (32 бита) позиционный номер, место первого байта данных сегмента в 
        потоке данных от источника до получателя.

Слово 3
    ACK (32 бита) квитанция

Слово 4
    LEN (4 бита) длина заголовка
    reserv (6 бит)
    FLAGS (6 бит) флаги
    WND (16 бит) размер окна приема

"""

import socket
import struct
import array


class TCPPacket:
    _srcip = None
    _dstip = None

    SRC = 0 # номер порта источника
    DST = 0
    SEQ = 1 # позиционный номер
    ACK = 0 # квитанция
    LEN = 5 # длина заголовка

    # flags
    f_URG = 0 
    f_ACK = 0
    f_PSH = 0
    f_RST = 0
    f_SYN = 0
    f_FIN = 0

    WND = 0 # размер окна приема
    CHK = 0 # контрольная сумма
    URG = 0 # указатель границы срочных данных
    OPT = ''    # опции
    DATA = ""   # данные

    
    def __init__(self) -> None:
        print("Create TCP packet instance")


    def set_data(self, data):
        self.DATA = data


    def set_ip(self, src, dst):
        self._srcip = src
        self._dstip = dst

    
    def to_hex_string(self):

        tcp_flags = self.f_FIN + (self.f_SYN << 1) + (self.f_RST << 2) + \
                    (self.f_PSH << 3) + (self.f_ACK << 4) + (self.f_URG << 5)
        
        header = struct.pack('!HHLLBBHHH', self.SRC, self.DST, self.SEQ, self.ACK,
                             (self.LEN << 4) + 0, tcp_flags, self.WND, self.CHK, self.URG)

        # calculate data length
        data_len = 0
        if(self.DATA != None):
            data_len = len(self.DATA)

        # create pseudo TCP packet
        psh = struct.pack('!4s4sBBH',
                          socket.inet_aton(self._srcip), # type: ignore
                          socket.inet_aton(self._dstip),                          0, # type: ignore
                          socket.IPPROTO_TCP,
                          len(header) + data_len)

        psh = psh + header + bytes(self.DATA, 'utf-8')

        crc = self.crc(psh)

        return struct.pack('!HHLLBBH', self.SRC, self.DST, self.SEQ, self.ACK, \
                            (self.LEN << 4) + 0, tcp_flags, self.WND) + \
                            struct.pack('H', crc) + struct.pack('!H', self.URG) + \
                            bytes(self.DATA, 'utf-8')



    def crc(self, data: bytes):
        if len(data)%2 != 0:
            data += b'\0'

        res = sum(array.array("H", data))
        res = (res >> 16) + (res & 0xffff)
        res += res >> 16

        return (~res) & 0xffff


    @staticmethod
    def get_falgs_from_byte(data):
        flags = [0,0,0,0,0,0]

        flags[0] = (data >> 5) & 1
        flags[1] = data >> 4 & 1
        flags[2] = data >> 3 & 1
        flags[3] = data >> 2 & 1
        flags[4] = data >> 1 & 1
        flags[5] = data & 1

        return (flags[0], flags[1], flags[2], flags[3], flags[4], flags[5])


    def __str__(self):
        s = ""
        s += "TCP Segment"
        s += "\tSource Port: {}, Destination Port: {}\n".format(self.SRC, self.DST)
        s += "\tSequence: {}, Acknowledgement: {}\n".format(self.SEQ, self.ACK)
        s += "\tFlags\n"
        s += "\t\tURG: {}, ACK: {}, PSH: {}\n".format(self.f_URG, self.f_ACK, self.f_PSH)
        s += "\t\tRST: {}, SYN: {}, FIN: {}\n".format(self.f_RST, self.f_SYN, self.f_FIN)
        s += "\tHeader Length: {}, Window Size: {}\n".format(self.LEN, self.WND)
        # Other parameters
        s += "Over parameters\n"
        s += "\t\tSRC ip: {}, DST ip: {}\n".format(self._srcip, self._dstip)
        return s





def main():
    pack = TCPPacket()
    print(pack)
    


if __name__ == "__main__":
    main()