Improvements to create-train-val-split.

+ Use python2 for consistency with our other Python code.
+ The previous version left empty subdirectories.
  This modification deletes them.
This commit is contained in:
Brandon Amos 2015-12-21 18:08:39 -05:00
parent e90694fcba
commit eecb0041e1
1 changed files with 22 additions and 10 deletions

View File

@ -1,4 +1,4 @@
#!/usr/bin/env python3
#!/usr/bin/env python2
#
# Copyright 2015 Carnegie Mellon University
#
@ -15,37 +15,44 @@
# limitations under the License.
import argparse
import errno
import os
import random
import shutil
def mkdirP(path):
try:
os.makedirs(path)
except OSError as exc: # Python >2.5
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
raise
def getImgs(imageDir):
exts = ["jpg", "png"]
# All images with one image from each class put into the validation set.
allImgsM = []
classes = {} # Directory Names -> 0-based indexes for Caffe classes.
classes = set()
valImgs = []
for subdir, dirs, files in os.walk(imageDir):
for fName in files:
(imageClass, imageName) = (os.path.basename(subdir), fName)
if any(imageName.lower().endswith("." + ext) for ext in exts):
if imageClass not in classes:
caffeClass = len(classes) # 0-based indexes.
classes[imageClass] = caffeClass
classes.add(imageClass)
valImgs.append((imageClass, imageName))
else:
allImgsM.append((imageClass, imageName))
return (allImgsM, classes, valImgs)
print("+ Number of Classes: '{}'.".format(len(classes)))
return (allImgsM, valImgs)
def createTrainValSplit(imageDir, valRatio):
print("+ Val ratio: '{}'.".format(valRatio))
(allImgsM, classes, valImgs) = getImgs(imageDir)
print("+ Number of Classes: '{}'.".format(len(classes)))
(allImgsM, valImgs) = getImgs(imageDir)
trainValIdx = int((len(allImgsM) + len(valImgs)) * valRatio) - len(valImgs)
assert(trainValIdx > 0) # Otherwise, valRatio is too small.
@ -61,16 +68,21 @@ def createTrainValSplit(imageDir, valRatio):
origPath = os.path.join(imageDir, person, img)
newDir = os.path.join(imageDir, 'train', person)
newPath = os.path.join(imageDir, 'train', person, img)
os.makedirs(newDir, exist_ok=True)
mkdirP(newDir)
shutil.move(origPath, newPath)
for person, img in valImgs:
origPath = os.path.join(imageDir, person, img)
newDir = os.path.join(imageDir, 'val', person)
newPath = os.path.join(imageDir, 'val', person, img)
os.makedirs(newDir, exist_ok=True)
mkdirP(newDir)
shutil.move(origPath, newPath)
for person, img in valImgs:
d = os.path.join(imageDir, person)
if os.path.isdir(d):
os.rmdir(d)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(