Improvements to create-train-val-split.
+ Use python2 for consistency with our other Python code. + The previous version left empty subdirectories. This modification deletes them.
This commit is contained in:
parent
e90694fcba
commit
eecb0041e1
|
@ -1,4 +1,4 @@
|
|||
#!/usr/bin/env python3
|
||||
#!/usr/bin/env python2
|
||||
#
|
||||
# Copyright 2015 Carnegie Mellon University
|
||||
#
|
||||
|
@ -15,37 +15,44 @@
|
|||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import errno
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
def mkdirP(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as exc: # Python >2.5
|
||||
if exc.errno == errno.EEXIST and os.path.isdir(path):
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
def getImgs(imageDir):
|
||||
exts = ["jpg", "png"]
|
||||
|
||||
# All images with one image from each class put into the validation set.
|
||||
allImgsM = []
|
||||
classes = {} # Directory Names -> 0-based indexes for Caffe classes.
|
||||
classes = set()
|
||||
valImgs = []
|
||||
for subdir, dirs, files in os.walk(imageDir):
|
||||
for fName in files:
|
||||
(imageClass, imageName) = (os.path.basename(subdir), fName)
|
||||
if any(imageName.lower().endswith("." + ext) for ext in exts):
|
||||
if imageClass not in classes:
|
||||
caffeClass = len(classes) # 0-based indexes.
|
||||
classes[imageClass] = caffeClass
|
||||
classes.add(imageClass)
|
||||
valImgs.append((imageClass, imageName))
|
||||
else:
|
||||
allImgsM.append((imageClass, imageName))
|
||||
return (allImgsM, classes, valImgs)
|
||||
print("+ Number of Classes: '{}'.".format(len(classes)))
|
||||
return (allImgsM, valImgs)
|
||||
|
||||
|
||||
def createTrainValSplit(imageDir, valRatio):
|
||||
print("+ Val ratio: '{}'.".format(valRatio))
|
||||
|
||||
(allImgsM, classes, valImgs) = getImgs(imageDir)
|
||||
|
||||
print("+ Number of Classes: '{}'.".format(len(classes)))
|
||||
(allImgsM, valImgs) = getImgs(imageDir)
|
||||
|
||||
trainValIdx = int((len(allImgsM) + len(valImgs)) * valRatio) - len(valImgs)
|
||||
assert(trainValIdx > 0) # Otherwise, valRatio is too small.
|
||||
|
@ -61,16 +68,21 @@ def createTrainValSplit(imageDir, valRatio):
|
|||
origPath = os.path.join(imageDir, person, img)
|
||||
newDir = os.path.join(imageDir, 'train', person)
|
||||
newPath = os.path.join(imageDir, 'train', person, img)
|
||||
os.makedirs(newDir, exist_ok=True)
|
||||
mkdirP(newDir)
|
||||
shutil.move(origPath, newPath)
|
||||
|
||||
for person, img in valImgs:
|
||||
origPath = os.path.join(imageDir, person, img)
|
||||
newDir = os.path.join(imageDir, 'val', person)
|
||||
newPath = os.path.join(imageDir, 'val', person, img)
|
||||
os.makedirs(newDir, exist_ok=True)
|
||||
mkdirP(newDir)
|
||||
shutil.move(origPath, newPath)
|
||||
|
||||
for person, img in valImgs:
|
||||
d = os.path.join(imageDir, person)
|
||||
if os.path.isdir(d):
|
||||
os.rmdir(d)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
|
|
Loading…
Reference in New Issue