/* Pass the Parcel */
/* Multihost Network + TCP/IP Stack Overloading Tool */
/* Version 0.01b */
/* Carl Ritson - 2001 */

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <time.h>
#include <signal.h>
#include <errno.h>

/*
 * XXX: Excessive comments inline, note to self, 
 *      don't spend a whole hour reading over the code once finished...
 */ 

/*
 * Multicast group every node joins and announces itself on (see mcast()).
 * FIXME: This is just a number I pulled from thin air :/
 * NOTE(review): 239.0.0.0/8 is the administratively scoped range (RFC 2365)
 * and is probably a better fit for a LAN-only tool -- consider moving there.
 */
#define MCAST_GROUP "224.7.7.7"

/* Compile-time tuning knobs */
enum {  
        /* UDP port used both to listen on and to send to */
        PORT            = 6505,
        /* Receive buffer size; the random packet is always shorter than this */
        PKTSIZE         = 1400,
        /* XXX: Warning: DEBUG outputs about 3-5 messages per packet */  
        DEBUG           = 0,
        /* 
        * The number of packets to receive and forward before regenerating
        * the random packet and injecting one extra packet into the network.
        * Also the minimum packet budget handed out per balance interval. */
        MAXPKT          = 5000,
        /* 
        * Number of hash buckets per zone, and the host count at which
        * add_host() moves on to filling the next zone.
        * It's a good idea to set ZONESIZE to around the number of hosts
        * you have of the same type.
        * e.g. if you have 10 100Mbps machines and 5 10Mbps machines,
        *       set ZONESIZE to 10.
        */
        ZONESIZE        = 4,
        /* 
        * Number of zones; the last one is reserved for hosts that went silent.
        * If you have lots of hosts with many different speeds you may want to
        * increase the number of ZONES.
        */
        ZONES           = 3,
        /* How frequently (in seconds) hosts are rebalanced between the ZONES */
        BALANCE_INTERVAL = 15,
        MAX_BALANCES    = 10, /* XXX: Only used when DEBUG > 0; caps the number of balances in a debug run */
        /*
        * Expected link speed in octets (bytes) per second. While the inbound
        * rate is below half of this, each interval gets MAXPKT extra packets.
        * You may want to change this to reflect your average link speed.
        */
        MAX_OCTETS_SEC  = 12*1024*1024,
        /*
        * The number of seconds to sleep before actually starting */
        PRESLEEP        = 10
};

/* Type of the values returned by hash(): a bucket index into zone[z][] */
typedef unsigned long hash_t;

/* 
 * Main host structure: holds a peer's address and its per-interval
 * statistics. Each host is simultaneously a node of host_list (via lnext)
 * and of one zone's hash chain (via hnext).
 */ 
struct host_t {
        hash_t hash;              /* Bucket index within zone[zone][] (from hash()) */
        struct sockaddr_in addr;  /* Peer address; the port is always PORT */
        struct host_t *hnext; /* Next host in the same hash bucket */
        struct host_t *lnext; /* Next host in host_list */
        unsigned long sent;       /* Packets we sent to this host this interval */
        unsigned long recv;       /* Packets received from this host this interval */
        unsigned long in_octets;  /* Bytes received from this host this interval */
        unsigned long out_octets; /* Bytes sent to this host this interval */
        int zone;                 /* Zone the host is in; -1 transiently inside balance_zones */
        int badness;              /* Bumped per silent interval, cleared in the top zones; freed once > 1 */
};

/* Frames received since the random packet was last regenerated (see main) */
int frames = 0;

/* Head of the singly linked list (via lnext) of every known host */
struct host_t *host_list = NULL;
/* One hash table per zone: ZONESIZE buckets, each chained via hnext */
struct host_t *zone[ZONES][ZONESIZE];
/* Number of hosts currently placed in each zone */
int hosts_in[ZONES];

int *zmap;   /* Slot -> zone map used for random zone selection (see init_zmap) */
int total_hosts = 0; /* Number of hosts in host_list */

/* Packets still to be allocated this interval, globally and per zone */
unsigned long allocs_remaining = 0;
unsigned long allocations[ZONES];

/* Cached 2 to the power ZONES: the number of slots in zmap */
int tpz;

/* The random packet and the number of its bytes currently in use */
char packet[PKTSIZE];
int pktsize = PKTSIZE;

/* UDP socket used for all sending and receiving */
int lsocket = 0;

/* BEGIN FUNCTIONS */

/*
 * hash, hashing function acting on an IPv4 address.
 * Folds the bytes of the address, as stored in struct in_addr (network
 * byte order), into a bucket index in the range [0, ZONESIZE).
 * The accumulator starts at 0; it was previously read uninitialized, which
 * is undefined behaviour and could place one address in different buckets
 * on different calls. Bytes are read as unsigned so no sign extension occurs.
 */
hash_t hash(struct in_addr *in) {
        const unsigned char *p = (const unsigned char *)in;
        hash_t ha = 0;
        size_t i;
        /* 4 bytes in an IPv4 address */
        for(i = 0; i < sizeof(struct in_addr); ++i) {
                ha = ha * 37 + p[i];
        }
        return (ha % ZONESIZE);
}

/*
 * power, returns in raised to the non-negative integer power pow.
 * power(x, 0) is 1.0 (the previous version returned x for pow == 0).
 */
double power(double in, unsigned int pow) {
        double out = 1.0;
        unsigned int i;
        for(i = 0; i < pow; ++i) {
                out *= in;
        }
        return out;
}

/*
 * halloc, allocates and initializes a host structure.
 *
 * ip: dotted-quad address string; takes precedence when non-NULL.
 * in: address in network byte order; used only when ip is NULL.
 * All counters, badness and links start at zero/NULL, and the sockaddr is
 * zero-filled (including sin_zero) before family/port are set. hash and
 * zone are left at 0 for the caller to assign. If ip cannot be parsed, or
 * neither ip nor in is supplied, the address is 0.0.0.0 instead of
 * uninitialized heap memory.
 * Returns NULL if the allocation fails.
 */
struct host_t *halloc(char *ip, struct in_addr *in) {
        struct host_t *p = calloc(1, sizeof *p);
        if(p == NULL) {
                return NULL;
        }
        /* calloc zeroes the bytes; set pointers explicitly for portability */
        p->hnext = NULL;
        p->lnext = NULL;
        p->addr.sin_family = AF_INET;
        p->addr.sin_port = htons(PORT);
        if(ip != NULL) {
                if(inet_aton(ip, &p->addr.sin_addr) == 0) {
                        p->addr.sin_addr.s_addr = INADDR_ANY;
                }
        } else if(in != NULL) {
                memcpy(&p->addr.sin_addr, in, sizeof(struct in_addr));
        }
        return p;
}

/*
 * place_host, appends host h to the tail of bucket h->hash in zone z.
 * h->hash must already hold the host's bucket index. Records z in h->zone,
 * bumps hosts_in[z] and terminates the chain at h.
 */
void place_host(struct host_t *h, int z) {
        struct host_t **link = &zone[z][h->hash];

        /* Walk the link fields until we find the empty one at the tail */
        while(*link != NULL) {
                link = &(*link)->hnext;
        }
        *link = h;

        h->zone = z;
        ++hosts_in[z];
        h->hnext = NULL;
}

/* adrcmp, returns 1 when the two IPv4 addresses are byte-for-byte equal, else 0 */
int adrcmp(struct in_addr *a, struct in_addr *b) {
        return memcmp(a, b, sizeof(struct in_addr)) == 0;
}

/*
 * lfind_host, linear search of host_list for a host by IP address.
 * The address is parsed from the string cip, or copied from the
 * network-byte-order in when cip is NULL. Returns the first matching host,
 * or NULL if none matches.
 * XXX: not called anywhere in this file.
 */
struct host_t *lfind_host(char *cip, struct in_addr *in) {
        struct in_addr target;
        struct host_t *cur;

        if(cip != NULL || in == NULL) {
                inet_aton(cip, &target);
        } else {
                memcpy(&target, in, sizeof(struct in_addr));
        }

        for(cur = host_list; cur != NULL; cur = cur->lnext) {
                if(adrcmp(&cur->addr.sin_addr, &target)) {
                        return cur;
                }
        }

        /* Only reached without a match, so cur is NULL here */
        if(DEBUG) {fprintf(stderr,"lfind_host, return: %p\n",(void *)cur);}

        return cur;
}

/*
 * find_host,
 * looks a host up by IP address, given as a dotted-quad string (cip) or in
 * network byte order (in, used only when cip is NULL).
 * Probes bucket hash(address) of each zone in order and returns the match
 * from the first zone that has one, or NULL if the host is unknown.
 * Within one chain a match at the head is taken as-is; otherwise the last
 * matching node of the chain is returned.
 */ 
struct host_t *find_host(char *cip, struct in_addr *in) {
        struct in_addr target;
        struct host_t *found = NULL;
        struct host_t *node;
        hash_t bucket;
        int z;

        if(cip != NULL || in == NULL) {
                inet_aton(cip, &target);
        } else {
                memcpy(&target, in, sizeof(struct in_addr));
        }
        bucket = hash(&target);

        /*
         * This is the only possible slow-down: with lots of ZONES we have to
         * check them all.
         */
        for(z = 0; z < ZONES && found == NULL; ++z) {
                node = zone[z][bucket];
                if(node == NULL) {
                        continue;
                }
                if(adrcmp(&node->addr.sin_addr, &target)) {
                        found = node;
                        continue;
                }
                while((node = node->hnext) != NULL) {
                        if(adrcmp(&node->addr.sin_addr, &target)) {
                                found = node;
                        }
                }
        }

        if(found == NULL && DEBUG) {fprintf(stderr,"find_host, return: %p\n",(void *)found);}

        return found;
}

/*
 * add_host,
 * creates a host for the given address, appends it to host_list and places
 * it in a zone hash table. The address comes from the string ip when it is
 * non-NULL, otherwise from the network-byte-order in (the same precedence
 * halloc uses).
 *
 * Returns NULL without adding anything for addresses we refuse to use:
 * no address at all, a string shorter than 4 characters or not a valid
 * address, 0.0.0.0, or anything in 127.0.0.0/8. (The loopback test used to
 * compare network-order s_addr with host-order INADDR_LOOPBACK, so it never
 * matched on little-endian hosts, and string input was never checked.)
 * Exits the program if allocation fails.
 */
struct host_t *add_host(char *ip, struct in_addr *in) {
        /*
         * Static so successive additions spread across the zones: we keep
         * filling zone z until it holds ZONESIZE hosts, then move on. The
         * last zone is never filled here; it is meant for bad hosts.
         */
        static int z = 0;
        struct in_addr addr;
        struct host_t *p;
        struct host_t *tail;
        unsigned long host_order;

        if(DEBUG) {fprintf(stderr,"add_host, ip %s\n",ip != NULL ? ip : "(null)");}

        /* Messy Exceptions */
        if(ip != NULL) {
                if(DEBUG) {fprintf(stderr,"add_host, ip '%s'\n",ip);}
                if(strlen(ip) < 4 || inet_aton(ip, &addr) == 0) {
                        return NULL;
                }
        } else if(in != NULL) {
                if(DEBUG) {fprintf(stderr,"add_host, in '%s'\n",inet_ntoa(*in));}
                memcpy(&addr, in, sizeof(struct in_addr));
        } else {
                return NULL;
        }
        host_order = ntohl(addr.s_addr);
        if(host_order == INADDR_ANY ||
                        (host_order & 0xff000000UL) == 0x7f000000UL) {
                return NULL;
        }

        /* Got past the exceptions, so we must be adding the host */
        p = halloc(NULL, &addr);
        if(p == NULL) {
                fprintf(stderr,"Malloc Failed\n");
                exit(1);
        }

        /* Append to the tail of host_list */
        if(host_list == NULL) {
                host_list = p;
                if(DEBUG) {fprintf(stderr,"add_host, set host_list: %p\n",(void *)p);}
        } else {
                tail = host_list;
                while(tail->lnext != NULL) {
                        tail = tail->lnext;
                }
                tail->lnext = p;
        }

        /* Rotate allocations once the current zone is full */
        if(hosts_in[z] >= ZONESIZE) {
                ++z;
                if(z >= (ZONES - 1)) {
                        z = 0;
                }
        }
        p->hash = hash(&p->addr.sin_addr);
        place_host(p, z); /* sets p->zone, links p, bumps hosts_in[z] */
        ++total_hosts;

        if(DEBUG) {fprintf(stderr,"add_host, return: %p\n",(void *)p);}

        return p;
}

/*
 * rgen,
 * returns a pseudo-random int x from rand() with 0 <= x < high (for
 * high > 0; rgen(0) is 0). The value is obtained by scaling rand() into
 * the range rather than by taking a remainder.
 */
int rgen(int high) {
        double scaled = (double)high * rand();
        return (int)(scaled / (RAND_MAX + 1.0));
}

/*
 * pick_zone, returns the zone the next packet should go to.
 *
 * While allocs_remaining is non-zero, it takes the lowest-numbered zone that
 * still has allocations, holds at least one host and was not the zone picked
 * last time (the static last_zone jitters the order). It then consumes one
 * allocation from that zone and from the global count.
 * Otherwise, or if no zone qualifies, it draws random zmap slots (which
 * favour the lower zones) until it hits a zone with at least one host.
 * NOTE: never returns if every hosts_in[] entry is 0; callers must check.
 */
int pick_zone() {
        static int last_zone = ZONES - 1;
        int z;

        if(allocs_remaining != 0) {
                for(z = 0; z < ZONES; ++z) {
                        if(z == last_zone || allocations[z] == 0 || hosts_in[z] <= 0) {
                                continue;
                        }
                        last_zone = z;
                        --allocs_remaining;
                        --allocations[z];
                        return z;
                }
        }

        do {
                z = zmap[rgen(tpz)];
        } while(hosts_in[z] < 1);
        return z;
}

/* rand_host, returns a 'random' host, not quite as it uses the 'best' zone */
struct host_t *rand_host() {
        struct host_t *p = NULL;
        int z;
        
        /* If there are no hosts to pick from why bother going any further */
        if(total_hosts == 0) {
                if(DEBUG) {fprintf(stderr,"rand_host: There are no hosts!\n");}
                return NULL;
        }
        z = pick_zone();
        
        if(DEBUG) {fprintf(stderr,"rand_host, z: %d\n",z);}

        /*
         * This is the main failing if a zone is quite empty we will spend
         * alot of time in here spinning till we find a host.
         * Thats why its a good idea to keep the ZONESIZE small 
         */
        while(p == NULL) {
                p = zone[z][rgen(ZONESIZE)];
        }
        
        /*
         */
        if(p != NULL) {
                struct host_t *h = p;
                int i = 0;
                /* Count the number of hosts in the chain */
                while(h != NULL) {
                        ++i;
                        h = h->hnext;
                }
                /* Pick one at random */
                z = rgen(i);
                for(i = 0; i < z; ++i) {
                        p = p->hnext;
                }
        }
        
        if(DEBUG) {fprintf(stderr,"rand_host, return: %p\n",p);}
        
        return p;
}

/*
 * add_ips, adds one host per line read from the stream in (see add_host).
 * Lines longer than the 1023-character buffer are truncated; the rest of
 * the line is discarded rather than being treated as a new entry. The old
 * bound allowed a write one byte past the buffer. Blank or bogus lines are
 * rejected by add_host. Does nothing if in is NULL.
 */
void add_ips(FILE *in) {
        char lb[1024];
        size_t len;
        int c;

        if(in == NULL) {
                return;
        }

        c = fgetc(in);
        while(c != EOF) {
                len = 0;
                while(c != EOF && c != '\n') {
                        if(len < sizeof(lb) - 1) {
                                lb[len++] = (char) c;
                        }
                        c = fgetc(in);
                }
                lb[len] = '\0';
                add_host(lb,NULL);
                /* Step past the '\n' (harmless at EOF) */
                c = fgetc(in);
        }
}

/*
 * gen_packet,
 * refills the global random packet. It picks a new length pktsize in
 * [0, PKTSIZE), so an empty payload is possible, and fills that many bytes
 * with values in [0, 254].
 */
void gen_packet() {
        char *cur;
        char *end;

        pktsize = rgen(PKTSIZE);
        end = packet + pktsize;
        for(cur = packet; cur != end; ++cur) {
                *cur = (char) rgen(255);
        }
}

/*
 * forward, sends size bytes from buffer to a host chosen by rand_host().
 * The chosen host's sent and out_octets counters are credited before the
 * send; the sendto() result is not checked. Does nothing if no host is
 * available.
 */
void forward(char *buffer, int size) {
        struct host_t *dest = rand_host();

        if(DEBUG) {fprintf(stderr,"Forward to: %p, size: %d\n",(void *)dest,size);}

        if(dest == NULL) {
                return;
        }
        ++dest->sent;
        dest->out_octets += size;
        sendto(lsocket, buffer, size, 0, (struct sockaddr *)&dest->addr, sizeof dest->addr);
}

/* send_rand, forwards the current random packet count times (none if count <= 0) */
void send_rand(int count) {
        while(count > 0) {
                forward(packet, pktsize);
                --count;
        }
}

/*
 * init_zones, empties every zone: all bucket heads become NULL and every
 * hosts_in[] count becomes 0. The host structures themselves are untouched.
 */
void init_zones() {
        int z;
        int bucket;

        for(z = 0; z < ZONES; ++z) {
                for(bucket = 0; bucket < ZONESIZE; ++bucket) {
                        zone[z][bucket] = NULL;
                }
                hosts_in[z] = 0;
        }
}

/*
 * report, prints to stderr the total host count, the number of hosts in
 * each zone, and every host's address, zone and badness.
 */
void report() {
        struct host_t *cur;
        int z;

        fprintf(stderr,"Total Hosts: %d\n",total_hosts);
        for(z = 0; z < ZONES; ++z) {
                fprintf(stderr,"%d Hosts in Zone %d\n",hosts_in[z],z);
        }
        for(cur = host_list; cur != NULL; cur = cur->lnext) {
                fprintf(stderr,"Host: %s, zone: %d, badness: %d\n",inet_ntoa(cur->addr.sin_addr),cur->zone,cur->badness);
        }
}

/* reset_counters, zeroes host h's per-interval packet and byte counters */
void reset_counters(struct host_t *h) {
        h->out_octets = h->in_octets = 0;
        h->sent = h->recv = 0;
}

/*
 * hdetails, prints host h's address and per-interval counters to stderr.
 * XXX: only used when DEBUG > 0.
 * The counters are unsigned long and need %lu; the old %d conversions were
 * undefined behaviour and print garbage on LP64 systems.
 */
void hdetails(struct host_t *h) {
        fprintf(stderr,"host: %s, recv: %lu, sent: %lu, in_octets: %lu, out_octets: %lu\n",inet_ntoa(h->addr.sin_addr),h->recv,h->sent,h->in_octets,h->out_octets);
}

/*
 * set_allocations, shares out the packet budget for the next interval.
 *
 * packets: base budget (balance_zones passes max(packets received, MAXPKT)).
 * octets:  bytes received during the last BALANCE_INTERVAL seconds.
 * If the inbound rate was below half of MAX_OCTETS_SEC, MAXPKT extra
 * packets are added. Walking from the last zone down, each populated zone z
 * takes 1/(z+1) of whatever is still unallocated, so zone 0 (if populated)
 * gets the remainder. Empty zones are reset to 0; previously they kept a
 * stale allocation from an earlier interval, which pick_zone would use as
 * soon as add_host put a host there.
 */
void set_allocations(unsigned long packets, unsigned long octets) {
        int z;
        if((octets/BALANCE_INTERVAL) < (MAX_OCTETS_SEC/2)) {
                packets += MAXPKT;
        }
        allocs_remaining = packets;
        for(z = (ZONES - 1); z >= 0; --z) {
                if(hosts_in[z] > 0) {
                        allocations[z] = packets / (z+1);
                        packets -= allocations[z];
                } else {
                        allocations[z] = 0;
                }
        }
}

/*
 * balance_zones, re-sorts the hosts into zones according to how many
 * packets each one sent us during the last interval. It also evicts hosts
 * that have gone silent, sets the packet allocations for the next interval
 * and prints the interval's traffic statistics.
 *
 * Placement (assuming ZONES >= 2), with highest = largest recv seen:
 *   zones 0 .. ZONES-3 : hosts with recv > highest/2, > highest/4, ...
 *                        in turn; their badness is cleared.
 *   zone  ZONES-2      : every other host that sent at least one packet.
 *   zone  ZONES-1      : hosts that sent nothing this interval.
 * A silent host has its badness incremented; once badness exceeds 1 it is
 * unlinked from host_list and freed. All per-host counters are then reset.
 *
 * Silent hosts used to be place_host()'d into the last zone *before*
 * init_zones() wiped the tables. That left them in no zone at all, with
 * hosts_in[] out of step with host_list; when every host was silent,
 * pick_zone() could spin forever. They are now only marked in the first
 * walk and placed after the tables are rebuilt.
 */
void balance_zones() {
        struct host_t *prev = NULL;   /* last host kept in host_list */
        struct host_t *h = host_list;
        struct host_t *next;
        unsigned long highest = 0;
        unsigned long tsent = 0;
        unsigned long trecv = 0;
        unsigned long osent = 0;
        unsigned long orecv = 0;
        double in_rate;
        double out_rate;
        int z;

        /*
         * First walk:
         * 1. take the statistics from every host (removed ones included),
         * 2. find the highest recv count,
         * 3. mark hosts that sent something with zone -1 (placed below),
         * 4. ++badness for silent hosts; remove them once badness > 1,
         *    otherwise mark them with zone ZONES-1 (the bad-host zone).
         */
        while(h != NULL) {
                next = h->lnext;
                orecv += h->in_octets;
                osent += h->out_octets;
                trecv += h->recv;
                tsent += h->sent;
                if(DEBUG) {hdetails(h);}

                if(h->recv > highest) {
                        highest = h->recv;
                }
                if(h->recv > 0) {
                        h->zone = -1;
                } else if(++h->badness > 1) {
                        if(DEBUG) {fprintf(stderr,"Removing %s\n",inet_ntoa(h->addr.sin_addr));}
                        if(prev == NULL) {
                                host_list = next;
                        } else {
                                prev->lnext = next;
                        }
                        free(h);
                        --total_hosts;
                        h = next;
                        continue;
                } else {
                        h->zone = (ZONES - 1);
                }
                prev = h;
                h = next;
        }

        if(DEBUG) {fprintf(stderr,"balance_zones, highest: %lu\n",highest);}
        /* Clear The Zones */
        init_zones();

        /* Fill all zones but the last two, halving the threshold each time */
        for(z = 0; z < (ZONES - 2); ++z) {
                highest = (unsigned long)((double)highest / 2.0);
                for(h = host_list; h != NULL; h = h->lnext) {
                        if(h->recv > highest && h->zone == -1) {
                                place_host(h,z);
                                h->badness = 0; /* Remove Badness */
                        }
                }
        }

        /*
         * Fill the zone before last with the remaining active hosts and the
         * last zone with the silent ones, resetting every host's counters.
         * Badness is deliberately not cleared in the zone before last, so a
         * host sending only a trickle that then stops is evicted quickly.
         */
        for(h = host_list; h != NULL; h = h->lnext) {
                if(h->recv <= highest && h->zone == -1) {
                        place_host(h,(ZONES - 2));
                } else if(h->zone == (ZONES - 1)) {
                        place_host(h,(ZONES - 1));
                }
                reset_counters(h);
        }

        /*
         * Set allocations for the zones, from whichever is greater:
         * total received packets or MAXPKT.
         */
        set_allocations((trecv < MAXPKT ? MAXPKT : trecv), orecv);

        /* Calculate and Print Statistics */
        out_rate = (((double) osent)/((double) BALANCE_INTERVAL)) / 1024.0;
        in_rate = (((double) orecv)/((double) BALANCE_INTERVAL)) / 1024.0;
        fprintf(stderr,"Rate in: %.0fkb, Rate out: %.0fkb\n", in_rate, out_rate);
        fprintf(stderr,"Packets in: %lu, Packets out: %lu\n", trecv, tsent);
}

/*
 * mcast, sends the current random packet to MCAST_GROUP:PORT as an
 * announcement of ourselves.
 * The destination is built once in a zero-filled static sockaddr_in.
 * Previously it was an unchecked malloc() with uninitialized sin_zero.
 */
void mcast() {
        static struct sockaddr_in group;
        static int ready = 0;
        int r;
        if(!ready) {
                memset(&group, 0, sizeof(group));
                group.sin_family        = AF_INET;
                group.sin_port          = htons(PORT);
                inet_aton(MCAST_GROUP,&(group.sin_addr));
                ready = 1;
        }
        r = sendto(lsocket, packet, pktsize, 0, (struct sockaddr *)&group, sizeof(struct sockaddr_in));
        if(DEBUG) {fprintf(stderr,"mcast, r: %d\n",r);}
}

/*
 * alarm_handler, SIGALRM handler that drives the periodic work.
 * On each call it:
 *   1. rebalances the zones (which also prints the traffic rates),
 *   2. reports the zone state,
 *   3. regenerates the random packet and injects 5 extra packets,
 *   4. re-announces us on the multicast group (these packets also feed
 *      the loop),
 *   5. re-arms the alarm for BALANCE_INTERVAL seconds.
 * With DEBUG set, the program exits on call number MAX_BALANCES + 1.
 *
 * NOTE(review): this runs fprintf, free() and host-list/zone-table surgery
 * from a signal handler while main() may be part-way through walking the
 * same structures (find_host, rand_host, add_host). None of that is
 * async-signal-safe. Setting a flag here and balancing from the main loop
 * would be the safe design.
 */
void alarm_handler(int sig) {
        static int calls = 0;

        (void)sig; /* only ever installed for SIGALRM */

        balance_zones();
        report();
        gen_packet();
        send_rand(5);
        mcast();
        alarm(BALANCE_INTERVAL);

        if(DEBUG && ++calls > MAX_BALANCES) {
                exit(0);
        }
}

/*
 * init_zmap, builds zmap, the slot -> zone table pick_zone() draws from.
 * zmap has tpz = 2^ZONES slots. Slot 0 maps to zone 0, then zone i gets
 * tpz / 2^(i+1) consecutive slots, so lower zones are picked more often
 * (for ZONES = 3 the map is 0,0,0,0,0,1,1,2). Also clears allocations[].
 * Exits if the map cannot be allocated.
 */
void init_zmap() {
        int zi;
        int slot;
        int width;
        int next = 1;

        /* 2 to the power ZONES, cached in tpz */
        tpz = 1 << ZONES;

        zmap = malloc(sizeof *zmap * tpz);
        if(zmap == NULL) {
                fprintf(stderr,"zmap: Malloc Failed\n");
                exit(1);
        }

        zmap[0] = 0; /* Zone 0 always has 1 extra mapping */
        for(zi = 0; zi < ZONES; ++zi) {
                allocations[zi] = 0;
                width = tpz >> (zi + 1);
                for(slot = 0; slot < width; ++slot) {
                        zmap[next + slot] = zi;
                }
                next += width;
        }

        if(DEBUG) {
                for(slot = 0; slot < tpz; ++slot) {
                        fprintf(stderr,"zmap[%d] = %d\n",slot,zmap[slot]);
                }
        }
}

/*
 * main, sets up this node and runs the receive-and-forward loop forever.
 *
 * Usage: prog [ip-list-file]
 * The optional file lists one dotted-quad peer address per line. The
 * program then:
 *   - binds a UDP socket to PORT on all interfaces and joins MCAST_GROUP,
 *   - waits PRESLEEP seconds,
 *   - forwards every datagram it receives to a host chosen by rand_host(),
 *     adding unknown senders as hosts on first contact.
 * Balancing runs from SIGALRM every BALANCE_INTERVAL seconds.
 * Exits with status 1 on any set-up failure.
 */
int main(int argc, char *argv[]) {
        char buffer[PKTSIZE];
        int size        = 0;
        socklen_t len   = 0;
        struct host_t *h;
        FILE *input;
        struct sockaddr_in addr;
        struct ip_mreqn mcast_addr;
        int r;
        int err;
        const int zero = 0;

        /* Program Listening Address (sin_zero must be zeroed) */
        memset(&addr, 0, sizeof(addr));
        addr.sin_family         = AF_INET;
        addr.sin_port           = htons(PORT);
        addr.sin_addr.s_addr    = htonl(INADDR_ANY);

        /* Clear Zones */
        init_zones();
        /* Initialize the Zone Mappings */
        init_zmap();

        /* A few Debug checks */
        if(DEBUG) {fprintf(stderr,"2^2 == %d\n",(int)power(2,2));}
        if(DEBUG) {fprintf(stderr,"0.5^2 == %f\n",power(0.5,2));}
        if(DEBUG) {fprintf(stderr,"tpz == %d\n",tpz);}

        /* Generate the first random packet */
        gen_packet();

        /* Check Arguments and Load IP List if there is one */
        if(argc >= 2) {
                input = fopen(argv[1],"r");
                if(input == NULL) {
                        fprintf(stderr,"Error: %s: %s\n",argv[1],strerror(errno));
                        exit(1);
                }
                add_ips(input);
                fclose(input);
        }

        /* Setup UDP Socket */
        lsocket = socket(PF_INET, SOCK_DGRAM, 0);
        if(lsocket == -1) {
                fprintf(stderr,"Error: %s\n",strerror(errno));
                exit(1);
        }
        r = bind(lsocket, (struct sockaddr *)&addr, sizeof(addr));
        if(r == -1) {
                fprintf(stderr,"Error: %s\n",strerror(errno));
                exit(1);
        }

        /* Join it to the Appropriate Multicast Group */
        memset(&mcast_addr, 0, sizeof(mcast_addr));
        inet_aton(MCAST_GROUP,&(mcast_addr.imr_multiaddr));
        mcast_addr.imr_address.s_addr = htonl(INADDR_ANY);
        mcast_addr.imr_ifindex = 0;

        /*
         * sys_errlist was removed from glibc (2.32); strerror() is the
         * portable replacement.
         */
        r = setsockopt(lsocket, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mcast_addr, sizeof(struct ip_mreqn));
        if(r == -1) {
                fprintf(stderr,"Error: %s\n",strerror(errno));
                exit(1);
        }
        /* Don't loop our own multicast announcements back to us */
        r = setsockopt(lsocket, IPPROTO_IP, IP_MULTICAST_LOOP, &zero, sizeof(int));
        if(r == -1) {
                fprintf(stderr,"Error: %s\n",strerror(errno));
                exit(1);
        }

        /* Setup Alarm Handler */
        signal(SIGALRM,&(alarm_handler));

        /* Report ZONE states */
        report();
        /* Sleep For PRESLEEP seconds, to make sure everything is ready */
        sleep(PRESLEEP);
        if(DEBUG) {fprintf(stderr,"Let the fun begin\n");}

        /* Program the alarm, this sets the balancing rolling */
        alarm(BALANCE_INTERVAL);
        /* Announce ourselves */
        mcast();
        /* Fire off some Packets */
        send_rand(5);

        /* This is the Main Loop */
        for(;;) {
                /*
                 * len is a value-result argument and must be reset on every
                 * call. It used to start at 0, so the kernel discarded the
                 * first sender's address and addr still held 0.0.0.0.
                 */
                len = sizeof(addr);
                size = recvfrom(lsocket, buffer, PKTSIZE, 0, (struct sockaddr *)&addr, &len);
                if(size < 0) {
                        /* Save errno before the SIGALRM handler can clobber it */
                        err = errno;
                        if(err != EINTR) {
                                fprintf(stderr,"Error: %s\n",strerror(err));
                        }
                        continue;
                }

                ++frames;
                if(DEBUG) {fprintf(stderr,"Got Frame, Cycle:%d\n",frames);}

                /* Get the h structure for the just received frame */
                h = find_host(NULL,&addr.sin_addr);
                if(h == NULL) {
                        /* If the host doesn't exist add it */
                        h = add_host(NULL,&addr.sin_addr);
                }
                /* add_host can fail if the host is bogus e.g. 0.0.0.0 */
                if(h != NULL) {
                        h->recv++;
                        h->in_octets += size;
                        forward(buffer, size);
                        /* Every MAXPKT frames the random packet is regenerated */
                        /* and another extra packet is sent */
                        if(frames >= MAXPKT) {
                                gen_packet();
                                send_rand(1);
                                frames = 0;
                        }
                } else {
                        if(DEBUG) {fprintf(stderr,"main, h == (NULL)\n");}
                }
        }
        return 0;
}