a | b/dataloaders/split_data.py | ||
---|---|---|---|
1 | import os |
||
2 | from sklearn.model_selection import train_test_split |
||
3 | |||
4 | data_path = 'E:/data/LASet' |
||
5 | names = os.listdir(os.path.join(data_path,'origin')) |
||
6 | train_ids,test_ids = train_test_split(names,test_size=0.2,random_state=367) |
||
7 | with open(os.path.join(data_path,'train.list'),'w') as f: |
||
8 | f.write('\n'.join(train_ids)) |
||
9 | with open(os.path.join(data_path,'test.list'),'w') as f: |
||
10 | f.write('\n'.join(test_ids)) |
||
11 | print(len(names),len(train_ids),len(test_ids)) |