# Create randomly sorted train and test sets with equal numbers of true and false edges.
import random
# Number of primary nodes to sample; half are used for true (held-out) edges
# and the other half for false (non-existent) edges.
samp = 9000


def read_edges(lines):
    """Parse ``primary,secondary`` lines into an edge list plus degree counts.

    Only the first two comma-separated fields of each line are used; the
    secondary field is stripped of surrounding whitespace (the primary field
    is kept verbatim, as in the original script).  Blank lines are skipped.

    Returns a tuple ``(edges, prim_connections, sec_connections)`` where
    ``edges`` is a list of ``(primary, secondary)`` tuples in input order and
    the two dicts map each node to the number of edges it appears in on the
    respective side.
    """
    edges = []
    prim_connections = {}
    sec_connections = {}
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines such as a trailing newline
        fields = line.split(',')
        a = fields[0]
        b = fields[1].strip()
        edges.append((a, b))
        prim_connections[a] = prim_connections.get(a, 0) + 1
        sec_connections[b] = sec_connections.get(b, 0) + 1
    return edges, prim_connections, sec_connections


def multi_edge_nodes(connections):
    """Return the set of nodes whose connection count is 2 or more."""
    return {node for node, degree in connections.items() if degree > 1}


def pick_non_neighbour(node, neighbours, candidates, max_tries=100):
    """Pick a random candidate that is neither ``node`` nor one of ``neighbours``.

    Rejection sampling is tried first because it is cheap when most
    candidates are eligible.  After ``max_tries`` misses the eligible
    candidates are enumerated explicitly, so the function always terminates.

    Returns ``None`` when no eligible candidate exists.
    """
    if not candidates:
        return None
    for _ in range(max_tries):
        pick = random.choice(candidates)
        if pick != node and pick not in neighbours:
            return pick
    eligible = [c for c in candidates if c != node and c not in neighbours]
    return random.choice(eligible) if eligible else None


def main(edge_path='complete4.txt',
         test_path='simplesplit_test.txt',
         validate_path='simplesplit_validate.txt',
         train_path='simplesplit_train.txt',
         sample_size=None):
    """Split an edge list into train / test / validate files.

    * Primary nodes with 2+ edges are sampled (at most ``sample_size``,
      default ``samp``) and divided into two halves.
    * For each node of the first half, the first edge met in random order
      whose secondary node keeps another edge (or is itself a primary node)
      is written to the test file and, labelled ``1``, to the validate file.
    * Every other edge is written to the train file.
    * For nodes of the second half, a random secondary node (with 2+ edges)
      that is not connected to the node is written to the test file and,
      labelled ``0``, to the validate file.

    Progress counters are printed to stdout as in the original script.
    """
    if sample_size is None:
        sample_size = samp

    with open(edge_path, 'r') as f1:
        edges, prim_connections, sec_connections = read_edges(f1)
    print('%d %d %d' % (len(edges), len(prim_connections), len(sec_connections)))

    # Universes of nodes with 2+ connections on each side.
    prim_universe = multi_edge_nodes(prim_connections)
    sec_universe = multi_edge_nodes(sec_connections)
    print('%d %d' % (len(prim_universe), len(sec_universe)))

    # Choose the two disjoint node samples.  Sampling is done from a list
    # (random.sample no longer accepts sets on recent Python versions) and the
    # size is clamped so a small input does not raise ValueError.
    prim_pool = list(prim_universe)
    sample = random.sample(prim_pool, min(sample_size, len(prim_pool)))
    sample1 = set(random.sample(sample, len(sample) // 2))
    sample2 = set(sample) - sample1
    print('%d %d %d' % (len(sample), len(sample1), len(sample2)))

    # Visit the edges in random order.
    random.shuffle(edges)

    neighbours = {}      # sample2 node -> set of its secondary neighbours
    sample2_edges = 0    # number of edges whose primary node is in sample2
    sample1_done = set()  # sample1 nodes that already contributed a test edge

    with open(test_path, 'w') as f2, \
            open(validate_path, 'w') as f3, \
            open(train_path, 'w') as f4:
        for a, b in edges:
            # Hold out one edge per sample1 node, but only if the inbound
            # node has another edge left (or also appears as a primary node).
            if (a in sample1 and a not in sample1_done
                    and (sec_connections[b] > 1 or b in prim_connections)):
                sec_connections[b] -= 1
                f2.write(a + ',' + b + '\n')      # test
                f3.write(a + ',' + b + ',1\n')    # validate
                sample1_done.add(a)
                print(len(sample1_done))
            else:
                f4.write(a + ',' + b + '\n')      # train
                if a in sample2:
                    # Index sample2 edges to speed up the non-pair check.
                    neighbours.setdefault(a, set()).add(b)
                    sample2_edges += 1
        print(sample2_edges)

        # For sample2 nodes choose non-connections (false edges).
        candidates = list(sec_universe)
        count = 0
        for node in sample2:
            # NOTE(review): the original condition was truncated to
            # `if count`; the bound len(sample1_done) is reconstructed from
            # the stated goal of equal numbers of true and false edges --
            # confirm against the original script.
            if count >= len(sample1_done):
                break
            pick = pick_non_neighbour(node, neighbours.get(node, ()), candidates)
            if pick is None:
                continue  # no eligible non-neighbour exists for this node
            count += 1
            print(count)
            f2.write(node + ',' + pick + '\n')     # test
            f3.write(node + ',' + pick + ',0\n')   # validate


if __name__ == '__main__':
    main()


# 0 comments:
# Post a Comment