Diff of /dataprocess/utils.py [000000] .. [f77492]

Switch to unified view

a b/dataprocess/utils.py
1
import numpy as np
2
import random
3
import copy
4
5
6
def data_resample4seg(train_datas, train_masks,index=-2, splitvalue=3, resample_rate=1, is_big=True):
7
    print('Before Resample datas:', len(train_datas))
8
    res_datas = copy.deepcopy(train_datas)
9
    res_masks = copy.deepcopy(train_masks)
10
    for train_data, train_mask in zip(train_datas,train_masks):
11
        if (is_big is True and int(train_data.split('_')[index]) >= splitvalue) or (is_big is False and int(train_data.split('_')[index]) < splitvalue):
12
            for i in range(resample_rate):
13
                res_datas.append(train_data)
14
                res_masks.append(train_mask)
15
    sorted(res_datas)
16
    sorted(res_masks)
17
    temp = list(zip(res_datas, res_masks))
18
    random.shuffle(temp)
19
    res_datas, res_masks = zip(*temp)
20
    print('After Resample datas:', len(res_datas))
21
    return res_datas, res_masks
22
23
24
def data_resample4cls(train_datas, index=-2, splitvalue=3, resample_rate=1, is_big=True):
25
    print('Before Resample datas:',len(train_datas))
26
    resample_datas = copy.deepcopy(train_datas)
27
    for train_data in train_datas:
28
        # print(train_data)
29
        value = int(train_data.split('_')[index])
30
        if (is_big is True and value >= splitvalue) or (is_big is False and value <= splitvalue):
31
            for i in range(resample_rate):
32
                resample_datas.append(train_data)
33
                # print(train_data)
34
    print('After Resample datas:', len(resample_datas))
35
    random.shuffle(resample_datas)
36
    return resample_datas