#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <stdlib.h>
#include <sys/wait.h>
#include <argp.h>
#include <signal.h>
#include <time.h>
#include <fcntl.h>
#include <sys/resource.h>

void
setcong(int s)
{
  /*
   * Switch TCP socket s to the "bpf_cubic" congestion-control algorithm and
   * print the algorithm the kernel reports afterwards as "New: <name>".
   *
   * Errors from either socket option are reported with perror(). If the
   * read-back fails nothing is printed: the buffer would not reflect the
   * socket's actual setting, so printing it would be misleading.
   */
  static const char want[] = "bpf_cubic";
  char buf[256] = {0};
  socklen_t len;

  /* optlen is the name length without the terminating NUL. */
  if (setsockopt(s, IPPROTO_TCP, TCP_CONGESTION, want, sizeof(want) - 1) != 0) {
    perror("setsockopt");
  }

  /* Reserve one byte for a terminator: the returned name is not guaranteed
     to be NUL-terminated when it fills the space the kernel copies. */
  len = sizeof(buf) - 1;
  if (getsockopt(s, IPPROTO_TCP, TCP_CONGESTION, buf, &len) != 0) {
    perror("getsockopt");
    return;
  }
  if (len > sizeof(buf) - 1)
    len = sizeof(buf) - 1;
  buf[len] = '\0';

  printf("New: %s\n", buf);
}

/* Port the server child listens on and the client connects to. */
enum { TEST_PORT = 8011 };

/* Pause between test steps, in microseconds. */
enum { STEP_DELAY_US = 200000 };

/* Prototype for setcong(), defined elsewhere in this file. */
void setcong(int s);

/*
 * Run cmd through the shell and report a non-zero status on stderr.
 * Failures are not fatal: the test keeps going so that later steps
 * (in particular the unregister commands) still run.
 */
static void
run_cmd(const char *cmd)
{
  /* The command writes straight to the fd; flush our pending output first
     so the two streams appear in program order. */
  fflush(stdout);

  int status = system(cmd);
  if (status != 0)
    fprintf(stderr, "command failed (status %d): %s\n", status, cmd);
}

/*
 * Create a TCP socket listening on INADDR_ANY:port with SO_REUSEADDR and
 * TCP_NODELAY enabled (backlog 10). Returns the fd, or -1 after reporting
 * the failing call with perror().
 */
static int
open_listener(in_port_t port)
{
  int ss = socket(AF_INET, SOCK_STREAM, 0);
  if (ss < 0) {
    perror("socket");
    return -1;
  }

  int yes = 1;
  setsockopt(ss, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));

  struct sockaddr_in sin;
  memset(&sin, 0, sizeof(sin));
  sin.sin_family = AF_INET;
  sin.sin_addr.s_addr = htonl(INADDR_ANY);
  sin.sin_port = htons(port);

  /* Without this bind the client's connect() to port cannot reach us
     (and might reach some other process), so fail early. */
  if (bind(ss, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
    perror("bind");
    close(ss);
    return -1;
  }

  if (listen(ss, 10) < 0) {
    perror("listen");
    close(ss);
    return -1;
  }

  setsockopt(ss, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes));
  return ss;
}

/*
 * Child-process body: accept one connection on the listening socket ss,
 * then read until EOF or error, printing each read() result as
 * "read <n>". Terminates the process; never returns.
 */
static void
serve_one(int ss)
{
  struct sockaddr_in peer;
  socklen_t peerlen = sizeof(peer);
  int s1 = accept(ss, (struct sockaddr *)&peer, &peerlen);

  if (s1 < 0) {
    perror("accept");   /* before close(), which may overwrite errno */
    close(ss);
    exit(EXIT_FAILURE);
  }
  close(ss);

  int yes = 1;
  setsockopt(s1, IPPROTO_TCP, TCP_NODELAY, &yes, sizeof(yes));

  char buf[512];
  for (;;) {
    ssize_t n = read(s1, buf, sizeof(buf));
    printf("read %d\n", (int)n);
    if (n <= 0)
      break;
  }

  close(s1);
  /* exit() (not _exit()) so any stdio output still buffered is flushed. */
  exit(EXIT_SUCCESS);
}

/*
 * Create a client socket, switch it to bpf_cubic via setcong(), and
 * connect it to port on the loopback address. Returns the connected fd,
 * or -1 after reporting the failing call with perror().
 */
static int
open_client(in_port_t port)
{
  int cs = socket(AF_INET, SOCK_STREAM, 0);
  if (cs < 0) {
    perror("socket");
    return -1;
  }

  /* Selected before connect(), i.e. before the connection is established. */
  setcong(cs);

  struct sockaddr_in sin;
  memset(&sin, 0, sizeof(sin));
  sin.sin_family = AF_INET;
  sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  sin.sin_port = htons(port);

  if (connect(cs, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
    perror("connect");
    close(cs);
    return -1;
  }
  return cs;
}

/*
 * Exercise unregistering a BPF struct_ops congestion-control algorithm
 * while a live TCP connection still uses it:
 *   1. start a forked server child that drains one connection;
 *   2. enable the BPF JIT and register bpf_cubic.o with bpftool;
 *   3. connect a client that selects bpf_cubic, and write to it;
 *   4. unregister "cubic" while the connection is open, then write again;
 *   5. close the connection and unregister again.
 * The shell commands write to /proc and call bpftool, so this presumably
 * needs root. Exits non-zero if the server or client could not be set up.
 */
int main(int argc, char **argv)
{
  (void)argc;
  (void)argv;

  /* Line-buffer stdout so progress lines from this process and the child
     come out in order even when stdout is redirected to a file or pipe. */
  setvbuf(stdout, NULL, _IOLBF, 0);

  /* A write to a peer that has gone away should fail with EPIPE (reported
     via perror) instead of killing us before the unregister steps run. */
  signal(SIGPIPE, SIG_IGN);

  /* Set up the server before registering anything, so an early failure
     does not leave bpf_cubic registered. */
  int ss = open_listener(TEST_PORT);
  if (ss < 0)
    return EXIT_FAILURE;

  pid_t pid = fork();
  if (pid < 0) {
    perror("fork");
    close(ss);
    return EXIT_FAILURE;
  }
  if (pid == 0)
    serve_one(ss);              /* never returns */

  close(ss);                    /* only the child accepts */

  run_cmd("echo 1 > /proc/sys/net/core/bpf_jit_enable");
  run_cmd("bpftool struct_ops register bpf_cubic.o");

  int cs = open_client(TEST_PORT);
  usleep(STEP_DELAY_US);
  if (cs >= 0 && write(cs, "xyz", 3) < 0)
    perror("write");

  usleep(STEP_DELAY_US);
  printf("first unregister:\n");
  run_cmd("bpftool struct_ops unregister name cubic");
  usleep(STEP_DELAY_US);

  if (cs >= 0 && write(cs, "abcd", 4) < 0)
    perror("write");

  usleep(STEP_DELAY_US);
  printf("close:\n");
  if (cs >= 0)
    close(cs);
  usleep(STEP_DELAY_US);

  printf("second unregister:\n");
  run_cmd("bpftool struct_ops unregister name cubic");
  usleep(STEP_DELAY_US);

  /* With no connection the child is still blocked in accept(); otherwise
     it exits on its own once it reads EOF after close(cs). */
  if (cs < 0)
    kill(pid, SIGTERM);
  if (waitpid(pid, NULL, 0) < 0)
    perror("waitpid");

  return cs < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
}
