diff --git a/dataset.py b/dataset.py
old mode 100644
new mode 100755
index 771ba1c..a73bc6b
--- a/dataset.py
+++ b/dataset.py
@@ -27,6 +27,8 @@ class Batch_Balanced_Dataset(object):
         _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW)
         self.data_loader_list = []
         self.dataloader_iter_list = []
+        batch_size_list = []
+        Total_batch_size = 0
         for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
             _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
             print('-' * 80)
@@ -45,6 +47,8 @@ class Batch_Balanced_Dataset(object):
                            for offset, length in zip(_accumulate(dataset_split), dataset_split)]
             print(f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}')
             print(f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}')
+            batch_size_list.append(str(_batch_size))
+            Total_batch_size += _batch_size
             _data_loader = torch.utils.data.DataLoader(
                 _dataset, batch_size=_batch_size,
@@ -54,6 +58,9 @@ class Batch_Balanced_Dataset(object):
             self.data_loader_list.append(_data_loader)
             self.dataloader_iter_list.append(iter(_data_loader))
             print('-' * 80)
+        print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size))
+        opt.batch_size = Total_batch_size
+        print('-' * 80)
 
     def get_batch(self):
         balanced_batch_images = []
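
For context (a minimal sketch, not part of the diff): the added lines record each dataset's per-loader batch size and overwrite opt.batch_size with their sum, so downstream code sees the true combined batch size. Assuming hypothetical values opt.batch_size = 192 and opt.batch_ratio = ['0.5', '0.5'] for two selected datasets, the arithmetic works out as follows:

    batch_size = 192                      # hypothetical stand-in for opt.batch_size
    batch_ratios = ['0.5', '0.5']         # hypothetical stand-in for opt.batch_ratio
    batch_size_list = []
    Total_batch_size = 0
    for batch_ratio_d in batch_ratios:
        # per-dataset batch size: round(batch_size * ratio), floored at 1
        _batch_size = max(round(batch_size * float(batch_ratio_d)), 1)
        batch_size_list.append(str(_batch_size))
        Total_batch_size += _batch_size
    print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size))
    # prints: Total_batch_size:  96+96 = 192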