Google


Wednesday, January 12, 2011

Kaggle social network challenge - test/train code

For those who participated in the Kaggle social network challenge, here is the Python code to split the full downloaded graph into test and training sets.

#create random sorted train set and test set with equal amounts of true and false edges

import random

#number of source nodes to sample; they are later split into two halves
#(one half gives up a true edge, the other half receives a generated false edge)
samp=9000

#import complete file
f1=open('complete4.txt','r') #input: one "source,target" edge per line
f2=open('simplesplit_test.txt','w') #output: test pairs, no label
f3=open('simplesplit_validate.txt','w') #output: the same pairs with a trailing ,1 / ,0 label
f4=open('simplesplit_train.txt','w') #output: edges kept for training


#load every edge and count how often each node appears as source / as target
prim=[]
prim_set=set()
sec_set=set()
prim_connections={}
prim_2plus=0
sec_connections={}
sec_2plus=0
for line in f1:
    fields=line.split(',')
    src=fields[0]
    dst=fields[1].strip()
    prim.append([src,dst,random.random()]) #random key, used later to shuffle the edges
    prim_connections[src]=prim_connections.get(src,0)+1 #outbound edge count
    sec_connections[dst]=sec_connections.get(dst,0)+1 #inbound edge count
    prim_set.add(src)
    sec_set.add(dst)

#edges, distinct sources, distinct targets
print('%s %s %s' % (len(prim),len(prim_connections),len(sec_connections)))

#universe of source nodes with 2+ outbound connections
prim_universe=set(node for node,degree in prim_connections.items() if degree>1)
prim_2plus+=len(prim_universe)

#universe of target nodes with 2+ inbound connections
sec_universe=set(node for node,degree in sec_connections.items() if degree>1)
sec_2plus+=len(sec_universe)

print('%s %s' % (prim_2plus,sec_2plus))

#draw samp source nodes, then split them into two disjoint halves
sample=random.sample(prim_universe,samp)
half=samp//2
sample1=set(random.sample(sample,half)) #nodes that will give up one true edge
sample2=set(node for node in sample if node not in sample1) #the remaining nodes

print('%s %s %s' % (len(sample),len(sample1),len(sample2)))

#shuffle the edges by sorting on the random key stored with each one
prim2=sorted(prim,key=lambda edge:edge[2])

del prim

prim3=[]
sample1_done=set()
for src,dst,_ in prim2:
    pair=src+','+dst
    if src not in sample1:
        f4.write(pair+'\n') #train
        if src in sample2: #keep these edges as a smaller list for the non pairs check
            prim3.append([src,dst])
        continue
    #hold out the first eligible edge of each sample1 node; eligible means the
    #target still has another inbound edge or appears as a source itself
    if src not in sample1_done and (sec_connections[dst]>1 or dst in prim_connections):
        sec_connections[dst]-=1
        f2.write(pair+'\n') #test
        f3.write(pair+',1\n') #validate
        sample1_done.add(src) #is done
        print(len(sample1_done))
    else:
        f4.write(pair+'\n') #train

del prim2

print(len(prim3))

#for sample2 chose non connections: one false edge per node, written to the
#test file and, labelled 0, to the validate file
#NOTE(review): the published listing is cut off after "if count". The bound is
#reconstructed here as len(sample1_done) so that the number of false edges
#matches the number of true edges written above - confirm against the original.

#known targets per sample2 source, built once instead of rescanning prim3 per node
known={}
for j in prim3:
    known.setdefault(j[0],set()).add(j[1])
no_targets=set()

#materialise the candidate pool once; sampling straight from the set converts it on every draw
sec_pool=list(sec_universe)

count=0
for i in sample2:
    if count>=len(sample1_done):
        break
    taken=known.get(i,no_targets)
    #redraw until the candidate is neither an existing target of i nor i itself
    #(does not terminate if every candidate is excluded)
    while True:
        rand=random.choice(sec_pool)
        if rand not in taken and rand!=i:
            break
    count+=1
    print(count)
    f2.write(i+','+rand+'\n') #test
    f3.write(i+','+rand+',0\n') #validate

f1.close()
f2.close()
f3.close()
f4.close()

No comments: