diff --git a/src/lib/utils/post_process.py b/src/lib/utils/post_process.py index 8346549..6894c0f 100644 --- a/src/lib/utils/post_process.py +++ b/src/lib/utils/post_process.py @@ -15,8 +15,8 @@ def get_alpha(rot): # bin2_cls[0], bin2_cls[1], bin2_sin, bin2_cos] # return rot[:, 0] idx = rot[:, 1] > rot[:, 5] - alpha1 = np.arctan(rot[:, 2] / rot[:, 3]) + (-0.5 * np.pi) - alpha2 = np.arctan(rot[:, 6] / rot[:, 7]) + ( 0.5 * np.pi) + alpha1 = np.arctan2(rot[:, 2], rot[:, 3]) + (-0.5 * np.pi) + alpha2 = np.arctan2(rot[:, 6], rot[:, 7]) + ( 0.5 * np.pi) return alpha1 * idx + alpha2 * (1 - idx)