41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
"""
|
|
Train YOLOv8 classification model on the generated dataset.
|
|
Classes: hammer (1), wrench (2), pliers (3) per competition spec.
|
|
"""
|
|
|
|
from ultralytics import YOLO
|
|
import os
|
|
import shutil
|
|
|
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
DATASET_DIR = os.path.join(BASE_DIR, "dataset")
|
|
|
|
model = YOLO("yolov8n-cls.pt") # nano classification model
|
|
|
|
results = model.train(
|
|
data=DATASET_DIR,
|
|
epochs=80,
|
|
imgsz=224,
|
|
batch=32,
|
|
patience=15,
|
|
project=os.path.join(BASE_DIR, "runs"),
|
|
name="module_v_cls",
|
|
exist_ok=True,
|
|
verbose=True,
|
|
)
|
|
|
|
# Validate
|
|
metrics = model.val()
|
|
print(f"\nValidation accuracy top-1: {metrics.top1:.4f}")
|
|
print(f"Validation accuracy top-5: {metrics.top5:.4f}")
|
|
|
|
# Export to ONNX
|
|
best_path = os.path.join(BASE_DIR, "runs", "module_v_cls", "weights", "best.pt")
|
|
export_model = YOLO(best_path)
|
|
export_model.export(format="onnx", imgsz=224)
|
|
print(f"\nModel exported to ONNX")
|
|
|
|
# Copy best.pt to project root
|
|
shutil.copy(best_path, os.path.join(BASE_DIR, "best.pt"))
|
|
print(f"Best model copied to {os.path.join(BASE_DIR, 'best.pt')}")
|