Add batch splitting ratio check

Check whether the splitting ratio adds up to one, e.g. `opt.batch_ratio="0.1-0.2"` would raise a error while "1"or "0.2-0.8" won't.
This commit is contained in:
#W[_t 2022-04-08 18:28:56 +08:00 committed by GitHub
parent 68a80fe979
commit 734a4e7ee4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 2 deletions

View File

@ -29,6 +29,8 @@ class Batch_Balanced_Dataset(object):
print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
assert len(opt.select_data) == len(opt.batch_ratio)
opt.batch_ratio = [float(x) for x in opt.batch_ratio]
assert np.sum(opt.batch_ratio) == 1.0
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
self.data_loader_list = []
@ -36,7 +38,7 @@ class Batch_Balanced_Dataset(object):
batch_size_list = []
Total_batch_size = 0
for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
_batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
_batch_size = max(round(opt.batch_size * batch_ratio_d), 1)
print(dashed_line)
log.write(dashed_line + '\n')
_dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
@ -54,7 +56,7 @@ class Batch_Balanced_Dataset(object):
_dataset, _ = [Subset(_dataset, indices[offset - length:offset])
for offset, length in zip(_accumulate(dataset_split), dataset_split)]
selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {batch_ratio_d} (batch_ratio) = {_batch_size}'
print(selected_d_log)
log.write(selected_d_log + '\n')
batch_size_list.append(str(_batch_size))