Friday, March 23, 2012

My Implementation of Simple Reliable Transport Protocol (SRTP)

 This protocol runs on top of UDP and provides reliability, retransmissions, acknowledgements, timeouts, and congestion control.


/*
 * Size limits for SRTP packets.
 *
 * The two derived macros are fully parenthesized so that they expand to a
 * single value inside a larger expression (e.g. `2 * MAX_SRTP_PACKET`).
 */
#define NUM_DELIMETERS   10   /* extra room added on top of MAX_SRTP_PACKET for the text form */
#define MAX_SRTP_HEADER  22
#define MAX_SRTP_PAYLOAD 81   /* size of srtp_header.data */
#define MAX_SRTP_PACKET  (MAX_SRTP_PAYLOAD + MAX_SRTP_HEADER)
#define MAX_NUM_PACKETS  2000 /* number of slots in the timer table t1[] */
#define MAX_SRTP_STRING  (MAX_SRTP_PACKET + NUM_DELIMETERS)
  
/*
 * Connection state codes, stored in connection.currentState
 * (see connS / connC).
 */
#define CLOSED        0   /* SRTP_Open() and SRTP_Listen() only create a socket from this state */
#define LISTEN        1   /* set by SRTP_Listen() once the server socket is bound */
#define SYN_SENT      2   /* not referenced in this part of the file */
#define SYN_RECVD     3   /* SRTP_Listen() answers with a SYN-ACK in this state */
#define ESTABLISHED   4   /* SRTP_Close() sends a FIN from this state */
#define FIN_WAIT1     5   /* set by SRTP_Close() after the FIN is sent */
#define FIN_WAIT2     6   /* SRTP_Close() sends the final ACK from this state */
#define CLOSE_WAIT    7   /* SRTP_Close() sends a FIN-ACK from this state */
#define LAST_ACK      8   /* set by SRTP_Close() after the FIN-ACK is sent */
#define TIMER_EXP     9   /* not referenced in this part of the file */

/* Error codes */
/*
 * Parenthesized so the value stays one operand wherever the macro is
 * expanded. Not referenced in this part of the file; presumably returned
 * when a received checksum does not match -- confirm against callers.
 */
#define CHKSUM_FAILED  (-2)

/*
 * Codes for the control field of an SRTP packet sent over UDP. They are
 * stored as strings in srtp_header.contrl_field and written verbatim into
 * the packet text by getPacket().
 *
 * NOTE(review): RST and RXM are both "R", so a receiver cannot tell them
 * apart by value -- confirm whether that is intended.
 */

#define ACK   "A"   /* sent by s_sendAckPkt() / c_sendAckPkt() */
#define RST   "R"
#define SYN   "S"   /* sent by sendSynPkt() */
#define FIN   "F"   /* sent by sendFinPkt() */
#define DAT   "D"   /* data segment that is not the last one (SRTP_Send) */
#define SYN_ACK "Z" /* sent by sendSynAckPkt() */
#define RXM   "R"
#define ERR   "E"
#define FIN_ACK "X" /* sent by sendFinAckPkt() */
#define DATL  "L"   /* identifies last segment */
#define PUT   "put"

/* Set to 1 by sendPackets() when a packet has timed out after 3 retransmissions. */
int conflag;

/*
 * Fixed-width integer aliases.
 *
 * uint32 is `unsigned int` rather than `unsigned long`: `unsigned long` is
 * 8 bytes on LP64 platforms, which contradicts the name. The typedef below
 * fails to compile (negative array size) if uint32 is not exactly 4 bytes.
 */
typedef unsigned int   uint32;
typedef unsigned short uint16;
typedef unsigned char  uint8;
typedef char srtp_uint32_must_be_4_bytes[(sizeof(uint32) == 4) ? 1 : -1];

/*
 * One SRTP segment: header fields followed by the payload.
 * getPacket() renders it as '&'-separated text; parser() rebuilds it
 * from that text.
 */
typedef struct
{      
       uint16  src_port_num;   /* SRTP_Send() uses 3490 */
       uint16  dest_port_num;  /* SRTP_Send() copies connC.portNumber */
       uint32  seq_num;        /* rand() for control packets, (index+1)*80 for data segments */
       uint32  ack_num;     /* set by s_sendAckPkt() / c_sendAckPkt(), otherwise 0 */
       uint16  window_size;    /* SRTP_Send() uses 10 */
       u_short  cksum;         /* value returned by checksum() */
       char    *contrl_field;  /* control code string; not owned: points at a string literal, a
                                  caller-supplied command, or into the buffer handed to parser() */
       uint32  reserved;       /* always 0 in this part of the file */
       char    data[MAX_SRTP_PAYLOAD];  /* payload, zero-filled by initialize() */
      
}srtp_header;

/*
 * One slot of the receive-side buffer. Both flags are cleared by
 * initializeBuff(); nothing else in this part of the file touches them.
 */
typedef struct
{
     srtp_header   pkt;        /* stored packet */
     bool          received;   /* presumably set once a packet is stored here -- confirm */
     bool          ack_sent;   /* presumably set once the packet was acknowledged -- confirm */
}incoming_buff;

/* Receive-side buffer; reset by initializeBuff(). */
incoming_buff server_buff[2000];

typedef struct

      srtp_header     pkt;
      int             num_rxmt;
      bool            ack_recd;
      bool            sent_pkt;
      bool            timed_out;
      bool            no_pkt;
      int             pID;
     
}timer;

timer t1[MAX_NUM_PACKETS];

/*
 * Send window bookkeeping. In this part of the file the fields are only
 * written by initWindow().
 */
typedef struct
{
       int                start_index;    /* initWindow() sets this to 0 */
       int                current_index;  /* initWindow() sets this to size - 1 */
       int                window_size;          /* initWindow() stores its size argument here */
}sWindow;
 
/* The single global send window. */
sWindow window; 
   
/*
 * State of one SRTP endpoint.
 */
typedef struct
{
       int                connection_number;  /* 1 for the client (SRTP_Open), 2 for the server (SRTP_Listen) */
       int                portNumber;         /* read by SRTP_Send() as destination port; not assigned in this part of the file */
       int                socketDesc;         /* UDP socket descriptor returned by socket() */
       int                currentState;       /* one of the state codes (CLOSED, LISTEN, ...) */
       struct sockaddr_in host_addr;          /* destination passed to sendto(); not assigned in this part of the file */
      
}connection;

/*
 * connS is used by the server-side functions, connC by the client-side
 * ones. As file-scope objects both start zeroed, i.e. in state CLOSED.
 */
connection connS, connC; 

//---------------------------------------------------------------------

/* Forward declarations; initialize() is called by the init helpers before it is defined. */
srtp_header initialize(void);
void sendPackets(int startIx, int numberOfPkts);


void initWindow(int size)

   extern sWindow window;

   window.start_index = 0;
   window.current_index = size-1;
   window.window_size = size;
   //printf("INITIALIZED WINDOW!\n");
return;
}

/*
 * Reset every slot of the timer table t1[] to "empty".
 *
 * The loop header was truncated in the source; it is restored to walk all
 * MAX_NUM_PACKETS slots. num_rxmt and sent_pkt are cleared as well so that
 * calling this again after packets were sent does not leave stale
 * retransmission state behind.
 */
void initTimer()
{
  extern timer t1[MAX_NUM_PACKETS];

  for (int i = 0; i < MAX_NUM_PACKETS; i++)
  {
     t1[i].pkt = initialize();
     t1[i].num_rxmt = 0;
     t1[i].ack_recd = 0;
     t1[i].sent_pkt = 0;
     t1[i].timed_out = 0;
     t1[i].pID = 0;
     t1[i].no_pkt = 1;   /* 1 = slot holds no packet */
  }
}

/*
 * Reset every slot of server_buff[]: store a blank packet and clear
 * the received / ack_sent flags.
 */
void initializeBuff()
{
  extern incoming_buff server_buff[2000];
  int slot = 0;

  while (slot < 2000)
  {
     incoming_buff *entry = &server_buff[slot];

     entry->pkt = initialize();
     entry->received = 0;
     entry->ack_sent = 0;
     slot++;
  }
}



/*
 * Store a packet in slot `sent` of the timer table and mark it ready to
 * be transmitted by sendPackets().
 *
 * hdr:  packet to queue (copied into the table).
 * sent: slot index; values outside 0 .. MAX_NUM_PACKETS-1 are ignored.
 *
 * The bounds check was truncated in the source; it is restored and also
 * rejects negative indices.
 */
void insertTimer(srtp_header hdr, int sent)
{
   extern timer t1[MAX_NUM_PACKETS];

   if (sent < 0 || sent >= MAX_NUM_PACKETS)
   {
     return;
   }

   t1[sent].pkt = hdr;
   printf("INSERTED PACKET! %d %s \n", sent,
          (hdr.contrl_field != NULL) ? hdr.contrl_field : "(null)");
   t1[sent].num_rxmt = 0;
   t1[sent].ack_recd = 0;
   t1[sent].sent_pkt = 0;
   t1[sent].timed_out = 1;   /* sendPackets() only sends slots flagged as timed out */
   t1[sent].no_pkt = 0;
   t1[sent].pID = 0;
}

//---------------------------------------------------------------------
/*
 * Compute the 16-bit checksum of a packet: the bitwise complement of the
 * sum of the header fields (except cksum itself) and all payload bytes,
 * truncated to u_short.
 *
 * The control field contributes atoi(contrl_field), which is 0 for the
 * letter codes ("A", "S", ...). That is kept as-is so both peers keep
 * computing the same value.
 *
 * NOTE(review): the payload loop bound was truncated in the source. The
 * whole data[] array is summed here, which equals summing only the used
 * bytes as long as the unused tail is zero.
 */
u_short checksum(srtp_header srtp)
{
    /* Unsigned accumulator: wraps instead of overflowing a signed int;
       the low 16 bits of the result are unaffected. */
    unsigned long sum = 0;
    int i;

    sum += srtp.src_port_num;
    sum += srtp.dest_port_num;
    sum += srtp.seq_num;
    sum += srtp.ack_num;
    sum += srtp.window_size;
    if (srtp.contrl_field != NULL)
    {
        sum += (unsigned long)atoi(srtp.contrl_field);
    }
    sum += srtp.reserved;

    for (i = 0; i < MAX_SRTP_PAYLOAD; i++)
    {
        sum += (unsigned long)srtp.data[i];
    }

    return (u_short)~sum;  /* takes the ones complement */
}

  /* reverse string s in place*/
  void reverse (char s[])
  {
      int c, i, j;

      for (i=0, j = strlen(s)-1; i
          c = s[i];
          s[i] = s[j];
          s[j] = c;
      }
  }  /* end of reverse*/


  /* converts n to characters in s */
  void itoa(int n, char s[])
  {
      int i, sign;

      if ((sign = n) < 0) {
          n = -n;
      }
      i=0;
      do {
          s[i++] = n % 10 + '0';
      } while ((n /= 10) > 0  );
      if (sign < 0) {
          s[i++] = '-';
      }
      s[i] = '\0';
      reverse(s);
  }   /* end of itoa */

  /* create packet  and place into circular buffer*/
char* getPacket(srtp_header srtp)
  {
     /*declarations*/
    char tmp_str[MAX_SRTP_PACKET + 10]; /*0416*/

    int i,j, k;
   
    static char *out;
    out = (char *)calloc(532, sizeof(char));
    const char *delimiter = "&";
    const char *delnull = "\0"; /*0416*/

    memset(tmp_str, NULL, sizeof(tmp_str) );
   
  
  
    printf("reading the input ...\n\n");

    /**************create header string**********************/
    strcat(out, delimiter);
    itoa(srtp.src_port_num, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);

    itoa(srtp.dest_port_num, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);
   
    itoa(srtp.seq_num, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);

    itoa(srtp.ack_num, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);

    itoa(srtp.window_size, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);
   
    itoa(srtp.cksum, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);
  
    strcat(out, srtp.contrl_field);
    strcat(out, delimiter);
   
    itoa(srtp.reserved, tmp_str);
    strcat(out, tmp_str);
    strcat(out, delimiter);

   
    strncpy(tmp_str, srtp.data, MAX_SRTP_PAYLOAD);
    strcat(out, tmp_str);
    strcat(out, delnull); /*0416*/
    //strcat(out, delimiter);
 
   

    printf("packet string  = %s\n\n", out);
   
    return out;
 
   

  } /* end of getPacket()*/


  /* parse packet into tokens */
srtp_header parser(char s[])
  {
    /*declarations */
    srtp_header srtp;
    const char *delimiter = "&";
    int i;
    char *data;
    data = (char *)calloc(MAX_SRTP_PAYLOAD, sizeof(char)); /*0416*/
    memset(data, NULL, MAX_SRTP_PAYLOAD); /*0416*/
    const char *delnull = "\0"; /*0416*/

   
   
   
    printf("You are inside Parser\n\n");

    /**************parse string********/
   
    srtp.src_port_num = (short)atoi(strtok(s, delimiter)); 
    srtp.dest_port_num = (short)atoi(strtok(NULL, delimiter));
    srtp.seq_num = (short)atoi(strtok(NULL, delimiter));
    srtp.ack_num = (short)atoi(strtok(NULL, delimiter));
    srtp.window_size = (short)atoi(strtok(NULL, delimiter));
    srtp.cksum = (u_short)atoi(strtok(NULL, delimiter));
    srtp.contrl_field = (strtok(NULL, delimiter) );
    srtp.reserved = (short)atoi(strtok(NULL, delimiter));
 
    //data = strtok(NULL, delimiter);
    data = strtok(NULL, delnull); /*0416*/

    
    if(data != NULL)
    {
      strncpy(srtp.data, data, MAX_SRTP_PAYLOAD); /*0416*/
    }
    

  

    /*printf("src_port_num = %d\n", srtp.src_port_num);
    printf("dest_port_num = %d\n", srtp.dest_port_num);
    printf("cksum_num = %d\n", srtp.cksum);
    printf("seq_num = %d\n", srtp.seq_num);
    printf("ack_num = %d\n", srtp.ack_num);
    printf("contrl_field = %s\n", srtp.contrl_field);*/
    if(data != NULL)
    {
      printf("data = %s\n", data);
    }
   
    printf("Leaving Parser\n\n");
   
    return srtp;

    /*************parse string**********/


  }  /* end of parser */
 
 
 
/*
 * Build a blank packet: every numeric field is 0, the control field is
 * the string "0" and the payload is zero-filled.
 */
srtp_header initialize(void)
{
    srtp_header blank;

    memset(blank.data, 0, MAX_SRTP_PAYLOAD );

    blank.contrl_field = "0";
    blank.src_port_num = blank.dest_port_num = 0;
    blank.seq_num = blank.ack_num = 0;
    blank.window_size = 0;
    blank.cksum = 0;
    blank.reserved = 0;

    return blank;
} /* end of initialize */

/*
 * Start a timer for packet `index` by forking a child process that sleeps
 * for `amount` seconds and then exits.
 *
 * Returns the child's pid (also stored in t1[index].pID), or -1 when
 * `index` is outside the table or fork() fails. On fork failure the slot's
 * pID is set to 0 ("no timer") rather than -1, so a later signal aimed at
 * the stored pid can never target pid -1.
 */
int cTimer(int index,int amount)
{
   extern timer t1[MAX_NUM_PACKETS];
   int procId;

   if (index < 0 || index >= MAX_NUM_PACKETS)
   {
      return -1;
   }

   procId = fork();

   if (procId == 0)
   {
      /* Child: sleep, report and exit without running the parent's atexit handlers. */
      printf("Timer %d started!\n", index);
      sleep(amount > 0 ? (unsigned int)amount : 0u);
      printf("Timer %d timer expired!\n", index);
      _exit(0);
   }

   if (procId == -1)
   {
      t1[index].pID = 0;
      return -1;
   }

   t1[index].pID = procId;
   return procId;
}


void sendPackets(int startIx, int numberOfPkts)
{  extern timer t1[MAX_NUM_PACKETS];
   extern int conflag;
   int r,i, slp, numTrmt;

   if ((startIx+numberOfPkts-1)>MAX_NUM_PACKETS)
        numTrmt = MAX_NUM_PACKETS;
   else numTrmt = startIx+numberOfPkts-1;

     
   //if ((t1[startIx].num_rxmt < 3)&&(t1[startIx].timed_out))
   //{  
  
      for (i=startIx;i<=(numTrmt);i++)
      {
        if((!t1[i].no_pkt)&&(t1[i].num_rxmt<3)&&(t1[i].timed_out))
        { t1[i].ack_recd = 0;
          t1[i].timed_out=0;  

      slp = cTimer(i,10);/*0418*/

          r = sendto(connC.socketDesc,getPacket(t1[i].pkt),MAX_SRTP_STRING,0,(struct sockaddr*)&connC.host_addr, sizeof(struct sockaddr));
     
          if (r==-1)
          {
           //printf("sendto fails\n");
       return;
          }
          else if (t1[i].sent_pkt)
          {
           t1[i].num_rxmt++;
       printf("REXMT PACKET %d FOR %d TIME!!\n", i, t1[i].num_rxmt);
          }
          else 
      {
       t1[i].num_rxmt = 0;
           t1[i].sent_pkt = 1;
           printf("=====PACKET %d SENT====\n", i);      
      }
     }
     else if((!t1[i].no_pkt)&&(t1[i].num_rxmt==3)&&(t1[i].timed_out))
     {
        printf("pkt %d experienced congestion\n",i);
            conflag = 1;
     }

      }  
   //}
   //else if (t1[startIx].num_rxmt==3)
   //{
   //   printf("**************CONGESTION EXPERIENCED*********\n");
     
   //}
return;  
}


//======================================BEGIN API FOR SRTP=============================================
int sendSynAckPkt()
  {
           int sendsynack;
           srtp_header synackPkt;
       extern connection connS;
       
       synackPkt = initialize();
           synackPkt.contrl_field = SYN_ACK;
           synackPkt.seq_num = rand();
      
           sendsynack = sendto(connS.socketDesc,getPacket(synackPkt),MAX_SRTP_PACKET,0,(struct sockaddr *)&connS.host_addr, sizeof(struct sockaddr));
      
           if (sendsynack == -1)
           {
              return -1;
           }
       else
       {
           return 1;
       }

      
  } /*end of sendSynAckPkt*/
 
 
int sendFinAckPkt()
  {
          int sendfinack;
          extern connection connS;
          srtp_header finackPkt;
      
          finackPkt = initialize();
          finackPkt.contrl_field = FIN_ACK;
          finackPkt.seq_num = rand();
         
          sendfinack = sendto(connS.socketDesc,getPacket(finackPkt),MAX_SRTP_PACKET,0,(struct sockaddr *)&connS.host_addr, sizeof(struct sockaddr));
      
              if (sendfinack == -1)
              {
                 return -1;
              }
              else
          {   
               return 1;
          }
          
    
  } /* end of sendFinAckPkt */
      
 
/*
 * Server side: send an ACK over connS whose ack_num is seq_num.
 * Returns 1 on success, -1 if sendto() fails.
 */
int s_sendAckPkt(int seq_num)
{
    extern connection connS;
    srtp_header ack = initialize();
    int rc;

    ack.ack_num = seq_num;
    ack.contrl_field = ACK;

    rc = sendto(connS.socketDesc, getPacket(ack), MAX_SRTP_PACKET, 0,
                (struct sockaddr *)&connS.host_addr, sizeof(struct sockaddr));

    return (rc == -1) ? -1 : 1;
} /* end of s_sendAckPkt */

int sendFinPkt()
 {
         int sendFin;
         extern connection connC;
         srtp_header finPacket;
             finPacket = initialize();
      
             finPacket.seq_num = rand();     
             finPacket.contrl_field = FIN;   
   
         sendFin = sendto(connC.socketDesc, getPacket(finPacket), MAX_SRTP_PACKET, 0, (struct sockaddr *)&connC.host_addr, sizeof(struct sockaddr));
        
         if (sendFin == -1)
             {
               return -1;
             }      
             else
             { 
             return 1;

         }
 } /*end of sendFinPkt */   
   
          
   
  
 int sendSynPkt()
 {

        int  sendSyn;
        extern connection connC;
        srtp_header synPacket;
        synPacket = initialize();
      
        synPacket.seq_num = rand();
        //printf("SYN SEQ NUM = %d\n\n", synPacket.seq_num );
        synPacket.contrl_field = SYN;
        sendSyn = sendto(connC.socketDesc, getPacket(synPacket), MAX_SRTP_PACKET, 0, (struct sockaddr *)&connC.host_addr, sizeof(struct sockaddr));
       
   
    if (sendSyn == -1)
        {
          return -1;
        }
    else
    {
      return 1;
    }
      
 }  /* end of sendsynpkt */

/*
 * Client side: send an ACK over connC whose ack_num is ackNum.
 * Returns 1 on success, -1 if sendto() fails.
 */
int c_sendAckPkt(int ackNum)
{
    extern connection connC;
    srtp_header ack = initialize();
    int rc;

    ack.ack_num = ackNum;
    ack.contrl_field = ACK;

    rc = sendto(connC.socketDesc, getPacket(ack), MAX_SRTP_PACKET, 0,
                (struct sockaddr *)&connC.host_addr, sizeof(struct sockaddr));

    return (rc == -1) ? -1 : 1;
} /* end of c_sendAckPkt */
     

/*** generate one packet that contains command, filename, filesize***/
int sendSetupPkt(char *filename, char *command, int namelength)
{  
        
        int sendPkt;
        srtp_header setupPkt;
        extern connection connC;
       
        setupPkt = initialize();
        setupPkt.contrl_field = command;
        memset(setupPkt.data,0,80);
        strncpy(setupPkt.data, filename, namelength);
        sendPkt = sendto(connC.socketDesc, getPacket(setupPkt), MAX_SRTP_PACKET+10, 0, (struct sockaddr*)&connC.host_addr, sizeof(struct sockaddr));
        if(sendPkt ==-1)
        {
            return -1;  /* error sending packet */
        }
        else
        {
            return 1;  /* packet sent ok */
        }
       
} /* end of sendSetupPkt */



// Client side: create a non-blocking UDP socket for connC, bind it to
// port_number on any local address and return the connection number (1).
//
// The socket is only created when connC is in state CLOSED; in any other
// state nothing is done and the connection number is still returned
// (previously the function ended without a return value in that case).
// Exits the process if socket() or bind() fails.
//
// NOTE(review): udp_host_addr is not used, connC.currentState is not
// changed and no handshake packet is sent here -- confirm that the caller
// fills connC.host_addr and drives the handshake.
int SRTP_Open(struct sockaddr *udp_host_addr, int port_number)
{
    extern connection connC;
    struct sockaddr_in client_addr;
    int connection_number = 1;

    (void)udp_host_addr;

    if (connC.currentState != CLOSED)
    {
        return connection_number;
    }

    //*************STEP 1: CREATE SOCKET***********************
    connC.socketDesc = socket(AF_INET, SOCK_DGRAM, 0);
    connC.connection_number = connection_number;

    if (connC.socketDesc == -1)
    {
        printf("Socket failed to create!\n");
        exit(1);
    }

    //************STEP 2: BIND SOCKET TO LOCAL ADDRESS**********
    memset(&client_addr, 0, sizeof client_addr);   // do not pass uninitialized padding to bind()
    client_addr.sin_family      = AF_INET;
    client_addr.sin_port        = htons(port_number);
    client_addr.sin_addr.s_addr = INADDR_ANY;

    if (bind(connC.socketDesc, (struct sockaddr *)&client_addr, sizeof client_addr) == -1)
    {
        printf("Bind Failed in SRTP_Open!\n");
        close(connC.socketDesc);
        exit(1);
    }

    fcntl(connC.socketDesc, F_SETFL, O_NONBLOCK);

    return connection_number;
}
   
//********************END SRTP_Open Function***************************
   
//***********STEP 1: GO TO STATE LISTEN************** 
    
int SRTP_Listen(int port_number)       

       //********** LISTEN FOR A SYN*************
       extern connection connS;
       int sockfd;
       struct sockaddr_in host_addr;
      
       connS.connection_number = 2;
      
    if (connS.currentState == CLOSED)
    { 
       //*********CREATE SOCKET**************
      
       sockfd = socket(AF_INET, SOCK_DGRAM, 0);
       connS.socketDesc = sockfd;
      
       host_addr.sin_family = AF_INET;
       host_addr.sin_port   = htons(port_number);
       host_addr.sin_addr.s_addr = INADDR_ANY;
      
       if(sockfd==-1)
       {
          printf("CREATE SOCKET FAILED!\n");
      exit(1);
       }
      
       //************ELSE BIND THE SOCKET*******
       else
       {
          connS.socketDesc = sockfd;
      if(bind(sockfd,(struct sockaddr *)&host_addr, sizeof(struct sockaddr)) == -1)
      {
         printf("Bind failed in SRTP_Listen!\n");
         close(sockfd);
         exit(1);
      }
      else
      {printf("Bind Successful!\n");
       fcntl(sockfd,F_SETFL,O_NONBLOCK);
      }
       }
           
       connS.currentState = LISTEN;
       //printf("Current State: %d\n",connS.currentState);

       return connS.connection_number;
     }
    
     if (connS.currentState == SYN_RECVD)
     { 
       //*************SEND A SYNACK********************
      
        int sendSyn = sendSynAckPkt();
      
      
       if (sendSyn == -1)
       {
          printf("SYNACK SEND ERROR!\n");
      connS.currentState = CLOSED;
      close(connS.socketDesc);
      exit(1);
       }
       else if (sendSyn != -1)
       {
         printf("SYNACK SENT!\n");   
       }
       return connS.connection_number;
      }     
    }

int SRTP_Close(int connection_number)
{
    extern connection connS,connC;
    //printf("INSIDE CLOSE!\n");
   
    if ((connS.connection_number==connection_number)&&(connS.currentState==CLOSE_WAIT))
    {
      //printf("SENDING A FINACK!\n");
      connS.currentState = LAST_ACK;
      int temp = sendFinAckPkt();
      return temp;
    }
    else if((connC.connection_number==connection_number)&&(connC.currentState==ESTABLISHED))
    {
      //printf("TERMINATING-SENDING A FIN\n");
      connC.currentState = FIN_WAIT1;
      int temp = sendFinPkt();
      return temp;
    }
    else if((connC.connection_number==connection_number)&&(connC.currentState==FIN_WAIT2))
    {
       connC.currentState = CLOSED;
       //printf("Sending Ack on FIN\n");
       int temp = c_sendAckPkt(0);
       return temp;
    }
}

int SRTP_Send(int conn_number, char*buffer_ptr, int buffer_size)

    int num, r, sent;
    static char *string;
    string = (char *)calloc(81, sizeof(char));
    srtp_header srtp;
    static char *result;
    char value[81];
    extern connection conn1;
   
   sent = num = 0;
   memset(string, NULL, MAX_SRTP_PAYLOAD + 1);
   memset(value, NULL, 81);
     
   while(*buffer_ptr != 0)
   {
     value[num++] = *buffer_ptr++;
             
     if( ((num / 80) == 1) )
     {
       //printf("String greater than 80\n\n\n\n");

      srtp = initialize();
      srtp.src_port_num = 3490;
      srtp.dest_port_num = connC.portNumber;
      srtp.window_size = 10;
      //srtp.contrl_field = DAT; /*0416*/
      srtp.ack_num = 0;
     
      if(*buffer_ptr == NULL)  /*0416*/
      {
         srtp.contrl_field = DATL;
      }
      else
      {
         srtp.contrl_field = DAT;   
      }

           
       
       srtp.seq_num = (sent+1)*80;
       memset(srtp.data,NULL,80);  
              
       strcat(string, value);
       strncpy(srtp.data, string, strlen(string));
        srtp.cksum = checksum(srtp);
        //printf("~~~~~~CHECKSUM ON THIS PACKET:  %d\n", srtp.cksum);     
         
       insertTimer(srtp,sent);    
       sent++;

       num = 0;
       memset(string, NULL, MAX_SRTP_PAYLOAD + 1);
       memset(value, NULL, 81);
     }   
   } /*end of whileloop */
    
   //printf("strlen = %d\n\n", strlen(value) ) ;
   if (strlen(value)>0)
   {
      //printf("String less than 80\n\n");

     srtp = initialize();
     srtp.src_port_num = 3490;
     srtp.dest_port_num = connC.portNumber;
     srtp.window_size = 10;
     srtp.ack_num = 0;
           
      
      srtp.contrl_field=DATL;
      srtp.seq_num = (sent+1)*80;
      memset(srtp.data,NULL,80);
                   
      strcat(string, value);
      strncpy(srtp.data, string, strlen(string));
      srtp.cksum = checksum(srtp);
      //printf("~~~~~~CHECKSUM ON THIS PACKET: %d\n", srtp.cksum);     
     

      insertTimer(srtp,sent);
      sent++;
      num = 0;
      memset(string, NULL, MAX_SRTP_PAYLOAD + 1);

     } /*end of if*/
         
     return sent;

}


/*
 * Parse the packet text in buffer_ptr and return buffer_size.
 *
 * Intended steps, per the original notes: read data from the UDP buffer,
 * parse the packet string, build the packet struct, and return the number
 * of bytes received or -1 on error. Only the parsing step exists so far:
 * the parsed packet is not yet passed to the global buffer, and no error
 * code is produced. conn_num is unused.
 */
int SRTP_Receive(int conn_num, char *buffer_ptr, int buffer_size)
{
    srtp_header parsed = parser(buffer_ptr);

    (void)parsed;   /* placeholder until the packet is handed to the global buffer */

    return buffer_size;
}

No comments: