Diff of /dataloaders/split_data.py [000000] .. [903821]

Switch to unified view

a b/dataloaders/split_data.py
1
import os
2
from sklearn.model_selection import train_test_split
3
4
data_path = 'E:/data/LASet'
5
names = os.listdir(os.path.join(data_path,'origin'))
6
train_ids,test_ids = train_test_split(names,test_size=0.2,random_state=367)
7
with open(os.path.join(data_path,'train.list'),'w') as f:
8
    f.write('\n'.join(train_ids))
9
with open(os.path.join(data_path,'test.list'),'w') as f:
10
    f.write('\n'.join(test_ids))
11
print(len(names),len(train_ids),len(test_ids))