fix legacy typos in code

This commit is contained in:
Xingyi Zhou 2019-04-18 17:42:04 -05:00
parent 7f411457b4
commit d19b524315
3 changed files with 12 additions and 11 deletions

View File

@ -273,6 +273,7 @@ class opts(object):
opt.exp_dir = os.path.join(opt.root_dir, 'exp', opt.task)
opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
opt.debug_dir = os.path.join(opt.save_dir, 'debug')
print('The output will be saved to ', opt.save_dir)
if opt.resume and opt.load_model == '':
model_path = opt.save_dir[:-4] if opt.save_dir.endswith('TEST') \
@ -354,7 +355,7 @@ class opts(object):
def __init__(self, entries):
for k, v in entries.items():
self.__setattr__(k, v)
opt = self.parse()
opt = self.parse(args)
dataset = Struct(default_dataset_info[opt.task])
opt.dataset = dataset.dataset
opt = self.update_dataset_info_and_set_heads(opt, dataset)

View File

@ -77,7 +77,7 @@ class DddTrainer(BaseTrainer):
opt = self.opt
wh = output['wh'] if opt.reg_bbox else None
reg = output['reg'] if opt.reg_offset else None
dets = ctadd_decode(output['hm'], output['rot'], output['dep'],
dets = ddd_decode(output['hm'], output['rot'], output['dep'],
output['dim'], wh=wh, reg=reg, K=opt.K)
# x, y, score, r1-r8, depth, dim1-dim3, cls
@ -86,10 +86,10 @@ class DddTrainer(BaseTrainer):
# x, y, score, rot, depth, dim1, dim2, dim3
# if opt.dataset == 'gta':
# dets[:, 12:15] /= 3
dets_pred = ctadd_post_process(
dets_pred = ddd_post_process(
dets.copy(), batch['meta']['c'].detach().numpy(),
batch['meta']['s'].detach().numpy(), calib, opt)
dets_gt = ctadd_post_process(
dets_gt = ddd_post_process(
batch['meta']['gt_det'].detach().numpy().copy(),
batch['meta']['c'].detach().numpy(),
batch['meta']['s'].detach().numpy(), calib, opt)
@ -138,14 +138,14 @@ class DddTrainer(BaseTrainer):
opt = self.opt
wh = output['wh'] if opt.reg_bbox else None
reg = output['reg'] if opt.reg_offset else None
dets = ctadd_decode(output['hm'], output['rot'], output['dep'],
dets = ddd_decode(output['hm'], output['rot'], output['dep'],
output['dim'], wh=wh, reg=reg, K=opt.K)
# x, y, score, r1-r8, depth, dim1-dim3, cls
dets = dets.detach().cpu().numpy().reshape(1, -1, dets.shape[2])
calib = batch['meta']['calib'].detach().numpy()
# x, y, score, rot, depth, dim1, dim2, dim3
dets_pred = ctadd_post_process(
dets_pred = ddd_post_process(
dets.copy(), batch['meta']['c'].detach().numpy(),
batch['meta']['s'].detach().numpy(), calib, opt)
img_id = batch['meta']['img_id'].detach().numpy()[0]

View File

@ -222,16 +222,16 @@ class Debugger(object):
else:
self.ax = None
nImgs = len(self.imgs)
fig=plt.figure(figsize=(nImgs * 10,10))
fig=self.plt.figure(figsize=(nImgs * 10,10))
nCols = nImgs
nRows = nImgs // nCols
for i, (k, v) in enumerate(self.imgs.items()):
fig.add_subplot(1, nImgs, i + 1)
if len(v.shape) == 3:
plt.imshow(cv2.cvtColor(v, cv2.COLOR_BGR2RGB))
self.plt.imshow(cv2.cvtColor(v, cv2.COLOR_BGR2RGB))
else:
plt.imshow(v)
plt.show()
self.plt.imshow(v)
self.plt.show()
def save_img(self, imgId='default', path='./cache/debug/'):
cv2.imwrite(path + '{}.png'.format(imgId), self.imgs[imgId])