first commit
This commit is contained in:
40
train_model.py
Normal file
40
train_model.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Train YOLOv8 classification model on the generated dataset.
|
||||
Classes: hammer (1), wrench (2), pliers (3) per competition spec.
|
||||
"""
|
||||
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
import shutil
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
DATASET_DIR = os.path.join(BASE_DIR, "dataset")
|
||||
|
||||
model = YOLO("yolov8n-cls.pt") # nano classification model
|
||||
|
||||
results = model.train(
|
||||
data=DATASET_DIR,
|
||||
epochs=80,
|
||||
imgsz=224,
|
||||
batch=32,
|
||||
patience=15,
|
||||
project=os.path.join(BASE_DIR, "runs"),
|
||||
name="module_v_cls",
|
||||
exist_ok=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Validate
|
||||
metrics = model.val()
|
||||
print(f"\nValidation accuracy top-1: {metrics.top1:.4f}")
|
||||
print(f"Validation accuracy top-5: {metrics.top5:.4f}")
|
||||
|
||||
# Export to ONNX
|
||||
best_path = os.path.join(BASE_DIR, "runs", "module_v_cls", "weights", "best.pt")
|
||||
export_model = YOLO(best_path)
|
||||
export_model.export(format="onnx", imgsz=224)
|
||||
print(f"\nModel exported to ONNX")
|
||||
|
||||
# Copy best.pt to project root
|
||||
shutil.copy(best_path, os.path.join(BASE_DIR, "best.pt"))
|
||||
print(f"Best model copied to {os.path.join(BASE_DIR, 'best.pt')}")
|
||||
Reference in New Issue
Block a user