Files
webots-vision-2/train_model.py
2026-04-03 07:30:54 +03:00

41 lines
1.1 KiB
Python

"""
Train YOLOv8 classification model on the generated dataset.
Classes: hammer (1), wrench (2), pliers (3) per competition spec.
"""
from ultralytics import YOLO
import os
import shutil
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATASET_DIR = os.path.join(BASE_DIR, "dataset")
RUNS_DIR = os.path.join(BASE_DIR, "runs")
RUN_NAME = "module_v_cls"
IMG_SIZE = 224  # shared by training and ONNX export so the two cannot drift


def _check_dataset(dataset_dir):
    """Raise FileNotFoundError unless *dataset_dir* is an existing directory.

    Fails fast with an explicit message instead of letting ultralytics fail
    later with a less direct error about the dataset path.
    """
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(
            f"Dataset directory not found: {dataset_dir}. "
            "Generate the dataset before training."
        )


def _train(model):
    """Train *model* on DATASET_DIR and return ultralytics' training result.

    NOTE(review): ultralytics derives class indices from the dataset's class
    sub-folder names (presumably sorted), so model indices are 0-based and
    may not follow the competition IDs hammer=1, wrench=2, pliers=3 -- confirm
    the mapping wherever predictions are consumed.
    """
    return model.train(
        data=DATASET_DIR,
        epochs=80,
        imgsz=IMG_SIZE,
        batch=32,
        patience=15,  # early-stopping patience (epochs without improvement)
        project=RUNS_DIR,
        name=RUN_NAME,
        # Reuse runs/module_v_cls rather than creating module_v_cls2, ... so
        # the best.pt path built in main() points at this run.
        exist_ok=True,
        verbose=True,
    )


def _report_validation(model):
    """Validate *model* and print its top-1 / top-5 accuracy."""
    metrics = model.val()
    print(f"\nValidation accuracy top-1: {metrics.top1:.4f}")
    # With only three classes, the top-5 set always contains the true class,
    # so this value is expected to be 1.0; printed for log continuity only.
    print(f"Validation accuracy top-5: {metrics.top5:.4f}")


def _export_and_publish(best_path):
    """Export the checkpoint at *best_path* to ONNX and copy it to BASE_DIR."""
    onnx_path = YOLO(best_path).export(format="onnx", imgsz=IMG_SIZE)
    print(f"\nModel exported to ONNX: {onnx_path}")
    target = os.path.join(BASE_DIR, "best.pt")
    shutil.copy(best_path, target)
    print(f"Best model copied to {target}")


def main():
    """Train, validate, export and publish the Module V classifier.

    Runs only under the ``__main__`` guard below: ultralytics' data loaders
    use multiprocessing, and with the spawn start method (Windows, macOS)
    worker processes re-import this module, which would otherwise re-run
    training at import time.
    """
    _check_dataset(DATASET_DIR)
    model = YOLO("yolov8n-cls.pt")  # nano classification model
    _train(model)
    _report_validation(model)
    best_path = os.path.join(RUNS_DIR, RUN_NAME, "weights", "best.pt")
    _export_and_publish(best_path)


if __name__ == "__main__":
    main()