update dataset print
This commit is contained in:
parent
7da666fc3a
commit
a89c0816bc
|
@ -27,6 +27,8 @@ class Batch_Balanced_Dataset(object):
|
|||
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
|
||||
self.data_loader_list = []
|
||||
self.dataloader_iter_list = []
|
||||
batch_size_list = []
|
||||
Total_batch_size = 0
|
||||
for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
|
||||
_batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
|
||||
print('-' * 80)
|
||||
|
@ -45,6 +47,8 @@ class Batch_Balanced_Dataset(object):
|
|||
for offset, length in zip(_accumulate(dataset_split), dataset_split)]
|
||||
print(f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}')
|
||||
print(f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}')
|
||||
batch_size_list.append(str(_batch_size))
|
||||
Total_batch_size += _batch_size
|
||||
|
||||
_data_loader = torch.utils.data.DataLoader(
|
||||
_dataset, batch_size=_batch_size,
|
||||
|
@ -54,6 +58,9 @@ class Batch_Balanced_Dataset(object):
|
|||
self.data_loader_list.append(_data_loader)
|
||||
self.dataloader_iter_list.append(iter(_data_loader))
|
||||
print('-' * 80)
|
||||
print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size))
|
||||
opt.batch_size = Total_batch_size
|
||||
print('-' * 80)
|
||||
|
||||
def get_batch(self):
|
||||
balanced_batch_images = []
|
||||
|
|
Loading…
Reference in New Issue