"""
|
||
|
|
Test detection on test.png - find individual boxes and classify each.
|
||
|
|
Saves annotated result to test_result.png
|
||
|
|
"""
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
import os
|
||
|
|
from ultralytics import YOLO
|
||
|
|
|
||
|
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||
|
|
|
||
|
|
YOLO_TO_COMPETITION = {0: 1, 1: 3, 2: 2} # hammer=1, pliers=3, wrench=2
|
||
|
|
CLASS_NAMES = {1: "hammer", 2: "wrench", 3: "pliers"}
|
||
|
|
CLASS_COLORS = {1: (0, 255, 0), 2: (255, 165, 0), 3: (0, 0, 255)}
|
||
|
|
|
||
|
|
|
||
|
|
def find_boxes(image):
    """Find box-like regions in the top-down camera view.

    The boxes are white/light rectangles on a darker shelf, so three
    complementary binary masks (gray brightness, low HSV saturation, high
    HSV value) are intersected, cleaned with morphology, and contoured.
    Candidates are filtered by area and aspect ratio, padded, and
    de-duplicated with a simple overlap-based suppression.

    Args:
        image: BGR image (numpy array) from the top-down camera.

    Returns:
        List of dicts with keys "roi" (cropped BGR image), "bbox"
        (x, y, w, h), "center" (cx, cy), "area" (contour area), and
        "contour", sorted by area descending with overlaps removed.

    Side effects:
        Writes the combined binary mask to debug_mask.png in BASE_DIR.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = image.shape[:2]

    # Approach 1: bright pixels -- boxes appear as light-colored rectangles.
    _, bright_mask = cv2.threshold(gray, 160, 255, cv2.THRESH_BINARY)

    # Approach 2: low saturation -- boxes are less saturated than the shelf.
    _, sat_mask = cv2.threshold(hsv[:, :, 1], 60, 255, cv2.THRESH_BINARY_INV)

    # Approach 3: high value channel -- boxes are brighter.
    _, val_mask = cv2.threshold(hsv[:, :, 2], 150, 255, cv2.THRESH_BINARY)

    # All three cues must agree.
    combined = cv2.bitwise_and(bright_mask, sat_mask)
    combined = cv2.bitwise_and(combined, val_mask)

    # Morphological cleanup: close gaps inside boxes, then drop speckle.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    combined = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, kernel, iterations=3)
    combined = cv2.morphologyEx(combined, cv2.MORPH_OPEN, kernel, iterations=2)

    # Save debug mask for offline inspection.
    cv2.imwrite(os.path.join(BASE_DIR, "debug_mask.png"), combined)

    contours, _ = cv2.findContours(combined, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    min_area = (h * w) * 0.005  # 0.5% of frame
    max_area = (h * w) * 0.15   # 15% of frame

    boxes = []
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < min_area or area > max_area:
            continue

        # NOTE: the original also computed cv2.minAreaRect/boxPoints here,
        # but the result was never used -- removed as dead code.
        x, y, bw, bh = cv2.boundingRect(cnt)

        # Aspect ratio filter - boxes should be somewhat rectangular,
        # not long thin slivers.
        aspect = max(bw, bh) / (min(bw, bh) + 1e-5)
        if aspect > 5:
            continue

        # Pad the bounding box slightly so box edges are fully included,
        # clamped to the image bounds.
        pad = 5
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(w, x + bw + pad)
        y2 = min(h, y + bh + pad)

        roi = image[y1:y2, x1:x2]
        if roi.size == 0:
            continue

        boxes.append({
            "roi": roi,
            "bbox": (x1, y1, x2 - x1, y2 - y1),
            "center": ((x1 + x2) // 2, (y1 + y2) // 2),
            "area": area,
            "contour": cnt,
        })

    # Largest regions first so suppression keeps the bigger of any overlap.
    boxes.sort(key=lambda b: b["area"], reverse=True)

    return _suppress_overlaps(boxes)


def _suppress_overlaps(boxes):
    """Drop boxes whose intersection covers >40% of the smaller of a pair.

    Expects *boxes* sorted by area descending, so larger boxes are kept
    in preference to smaller overlapping ones.
    """
    filtered = []
    for box in boxes:
        x1, y1, bw1, bh1 = box["bbox"]
        keep = True
        for kept in filtered:
            x2, y2, bw2, bh2 = kept["bbox"]
            # Intersection rectangle of the two bounding boxes.
            ix1 = max(x1, x2)
            iy1 = max(y1, y2)
            ix2 = min(x1 + bw1, x2 + bw2)
            iy2 = min(y1 + bh1, y2 + bh2)
            if ix2 > ix1 and iy2 > iy1:
                inter = (ix2 - ix1) * (iy2 - iy1)
                area_small = min(bw1 * bh1, bw2 * bh2)
                # If intersection covers >40% of the smaller box, drop it.
                if inter / (area_small + 1e-5) > 0.4:
                    keep = False
                    break
        if keep:
            filtered.append(box)
    return filtered
|
||
|
|
|
||
|
|
|
||
|
|
def classify_roi(model, roi):
    """Classify a single cropped region with the YOLO classifier.

    Returns a (competition_class_id, confidence) pair; (-1, 0.0) when the
    model yields no class probabilities.
    """
    predictions = model(roi, imgsz=224, verbose=False)
    # Guard clause: bail out if the model produced nothing usable.
    if not predictions or predictions[0].probs is None:
        return -1, 0.0
    probs = predictions[0].probs
    comp_class = YOLO_TO_COMPETITION.get(probs.top1, -1)
    return comp_class, probs.top1conf.item()
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Detect and classify boxes in test.png, writing an annotated copy."""
    model = YOLO(os.path.join(BASE_DIR, "best.pt"))
    image = cv2.imread(os.path.join(BASE_DIR, "test.png"))
    if image is None:
        print("Error: cannot read test.png")
        return

    print(f"Image size: {image.shape[1]}x{image.shape[0]}")

    # Locate candidate box regions via color/brightness segmentation.
    boxes = find_boxes(image)
    print(f"Found {len(boxes)} box regions")

    annotated = image.copy()
    results_list = []

    # Classify each candidate region and draw it on the annotated copy.
    for idx, box in enumerate(boxes, start=1):
        comp_class, conf = classify_roi(model, box["roi"])
        name = CLASS_NAMES.get(comp_class, "unknown")
        results_list.append((comp_class, conf, box["center"], box["bbox"]))

        print(f" Box {idx}: {name} (ID={comp_class}, conf={conf:.4f}) "
              f"center=({box['center'][0]}, {box['center'][1]}) "
              f"area={box['area']:.0f}")

        x, y, bw, bh = box["bbox"]
        color = CLASS_COLORS.get(comp_class, (128, 128, 128))
        cv2.rectangle(annotated, (x, y), (x + bw, y + bh), color, 2)
        cv2.putText(annotated, f"{name} {conf:.2f}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    out_path = os.path.join(BASE_DIR, "test_result.png")
    cv2.imwrite(out_path, annotated)
    print(f"\nAnnotated result saved to: {out_path}")

    # Per-class tally of everything detected.
    print("\n--- Summary ---")
    class_counts = {}
    for comp_class, _conf, _center, _bbox in results_list:
        label = CLASS_NAMES.get(comp_class, "unknown")
        class_counts[label] = class_counts.get(label, 0) + 1
    for label, count in class_counts.items():
        print(f" {label}: {count}")

    # Segmentation found suspiciously few regions -- fall back to a
    # sliding-window scan of the whole image.
    if len(boxes) < 3:
        print("\n--- Fallback: sliding window approach ---")
        sliding_window_detect(model, image)
|
||
|
|
|
||
|
|
|
||
|
|
def sliding_window_detect(model, image):
    """Fallback detector: classify fixed-size windows slid over the image.

    Windows whose top-1 confidence exceeds 0.8 become candidates; these are
    de-duplicated by center distance with highest-confidence-first priority.
    Kept detections are printed and drawn onto a copy of *image*, which is
    saved as test_result_sliding.png. Nothing is written when no window
    passes the confidence threshold.
    """
    h, w = image.shape[:2]
    step_ratio = 0.3
    candidates = []

    # Two window scales, stepping 30% of the window size each move.
    for win_h, win_w in [(h // 3, w // 4), (h // 2, w // 3)]:
        dy = int(win_h * step_ratio)
        dx = int(win_w * step_ratio)
        for top in range(0, h - win_h + 1, dy):
            for left in range(0, w - win_w + 1, dx):
                window = image[top:top + win_h, left:left + win_w]
                comp_class, conf = classify_roi(model, window)
                if conf > 0.8:
                    cx = left + win_w // 2
                    cy = top + win_h // 2
                    candidates.append(
                        (comp_class, conf, cx, cy, left, top, win_w, win_h))

    if not candidates:
        return

    # NMS-like suppression: process highest confidence first, and drop any
    # candidate whose center lands near an already-kept detection.
    candidates.sort(key=lambda d: d[1], reverse=True)
    kept = []
    for det in candidates:
        cx, cy = det[2], det[3]
        close = any(abs(cx - k[2]) < w // 5 and abs(cy - k[3]) < h // 5
                    for k in kept)
        if not close:
            kept.append(det)

    annotated = image.copy()
    for comp_class, conf, cx, cy, x, y, ww, wh in kept:
        name = CLASS_NAMES.get(comp_class, "unknown")
        color = CLASS_COLORS.get(comp_class, (128, 128, 128))
        cv2.rectangle(annotated, (x, y), (x + ww, y + wh), color, 2)
        cv2.putText(annotated, f"{name} {conf:.2f}", (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        print(f" {name} (conf={conf:.2f}) at ({cx}, {cy})")

    out_path = os.path.join(BASE_DIR, "test_result_sliding.png")
    cv2.imwrite(out_path, annotated)
    print(f" Saved to: {out_path}")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|