You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
44 lines
2.1 KiB
Python
44 lines
2.1 KiB
Python
3 weeks ago
|
import os
|
||
|
import shutil
|
||
|
import random
|
||
|
|
||
|
def copy_files(labels_dir, images_dir, ratio, parent_dir, classes,
               dataset_subdir="dataset/gesture", image_ext=".jpg"):
    """Split label/image pairs into train and validation folders, class by class.

    Every ``*.txt`` file in *labels_dir* (except ``classes.txt``) is assigned to
    the class whose name is the longest prefix of the file name; files matching
    no class are ignored. Within each class the files are shuffled with the
    ``random`` module, the first ``int(ratio * count)`` go to ``train`` and the
    rest to ``validation``. For each label ``<stem>.txt`` the image
    ``<stem><image_ext>`` from *images_dir* is copied into the same split.

    Args:
        labels_dir: Directory containing the label ``.txt`` files.
        images_dir: Directory containing the matching images.
        ratio: Fraction of each class sent to the training split.
        parent_dir: Base directory under which *dataset_subdir* is created.
        classes: Class names, matched as file-name prefixes. Blank and
            duplicate entries are ignored.
        dataset_subdir: Output path relative to *parent_dir*; the folders
            ``labels/{train,validation}`` and ``images/{train,validation}``
            are created beneath it.
        image_ext: Extension of the image files, including the dot.

    Raises:
        FileNotFoundError: If a label has no matching image. The image is
            copied before its label, so no orphan label is left behind.
    """
    # Create the destination folders.
    dataset_root = os.path.join(parent_dir, dataset_subdir)
    split_dirs = {}
    for split in ("train", "validation"):
        labels_dst_dir = os.path.join(dataset_root, "labels", split)
        images_dst_dir = os.path.join(dataset_root, "images", split)
        os.makedirs(labels_dst_dir, exist_ok=True)
        os.makedirs(images_dst_dir, exist_ok=True)
        split_dirs[split] = (labels_dst_dir, images_dst_dir)

    # Drop blank names (an empty prefix matches every file) and duplicates
    # (which would copy the same files twice), then order longest-first so a
    # file like "one_hand_1.txt" is assigned only to "one_hand", not also to
    # "one". Otherwise one file could land in train for one class and in
    # validation for another, leaking data between the splits.
    names = sorted(dict.fromkeys(c for c in classes if c), key=len, reverse=True)
    groups = {name: [] for name in names}
    for txt in sorted(os.listdir(labels_dir)):
        if not txt.endswith(".txt") or txt == "classes.txt":
            continue  # not a per-image label file
        for name in names:
            if txt.startswith(name):
                groups[name].append(txt)
                break

    for label_cls in groups.values():
        random.shuffle(label_cls)
        n_train = int(ratio * len(label_cls))
        for i, txt in enumerate(label_cls):
            split = "train" if i < n_train else "validation"
            labels_dst_dir, images_dst_dir = split_dirs[split]
            # Swap only the extension, not every ".txt" inside the name.
            image_name = os.path.splitext(txt)[0] + image_ext
            # Copy the image first: if it is missing, the FileNotFoundError
            # is raised before an orphan label is written.
            shutil.copy(os.path.join(images_dir, image_name),
                        os.path.join(images_dst_dir, image_name))
            shutil.copy(os.path.join(labels_dir, txt),
                        os.path.join(labels_dst_dir, txt))
|
||
|
|
||
|
# Script entry point: split the label/image pairs found in the "labels" and
# "images" folders one level above this script into train/validation sets,
# sending 80% of each class to train.
if __name__ == '__main__':
    current_dir = os.path.abspath(os.path.dirname(__file__))
    parent_dir = os.path.dirname(current_dir)
    labels_dir = os.path.join(parent_dir, "labels")
    images_dir = os.path.join(parent_dir, "images")
    # NOTE(review): classes.txt is opened relative to the current working
    # directory, unlike the paths above, which are anchored to this script's
    # location — confirm this is intended.
    with open("classes.txt", "r") as f:
        # Strip stray whitespace and skip blank lines: an empty class name is
        # a prefix of every file name and would pull all files into one group.
        classes = [line.strip() for line in f if line.strip()]
    copy_files(labels_dir, images_dir, 0.8, parent_dir, classes)
|