Add batch splitting ratio check
Check whether the splitting ratio adds up to one, e.g. `opt.batch_ratio="0.1-0.2"` would raise a error while "1"or "0.2-0.8" won't.
This commit is contained in:
parent
68a80fe979
commit
734a4e7ee4
|
@ -29,6 +29,8 @@ class Batch_Balanced_Dataset(object):
|
|||
print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
|
||||
log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
|
||||
assert len(opt.select_data) == len(opt.batch_ratio)
|
||||
opt.batch_ratio = [float(x) for x in opt.batch_ratio]
|
||||
assert np.sum(opt.batch_ratio) == 1.0
|
||||
|
||||
_AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
|
||||
self.data_loader_list = []
|
||||
|
@ -36,7 +38,7 @@ class Batch_Balanced_Dataset(object):
|
|||
batch_size_list = []
|
||||
Total_batch_size = 0
|
||||
for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
|
||||
_batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
|
||||
_batch_size = max(round(opt.batch_size * batch_ratio_d), 1)
|
||||
print(dashed_line)
|
||||
log.write(dashed_line + '\n')
|
||||
_dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
|
||||
|
@ -54,7 +56,7 @@ class Batch_Balanced_Dataset(object):
|
|||
_dataset, _ = [Subset(_dataset, indices[offset - length:offset])
|
||||
for offset, length in zip(_accumulate(dataset_split), dataset_split)]
|
||||
selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
|
||||
selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
|
||||
selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {batch_ratio_d} (batch_ratio) = {_batch_size}'
|
||||
print(selected_d_log)
|
||||
log.write(selected_d_log + '\n')
|
||||
batch_size_list.append(str(_batch_size))
|
||||
|
|
Loading…
Reference in New Issue