first commit
This commit is contained in:
235
test_detect.py
Normal file
235
test_detect.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Test detection on test.png - find individual boxes and classify each.
|
||||
Saves annotated result to test_result.png
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Directory containing this script; all file I/O below is relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Map the YOLO model's raw class indices to competition class IDs.
YOLO_TO_COMPETITION = {0: 1, 1: 3, 2: 2}  # hammer=1, pliers=3, wrench=2

# Human-readable names keyed by competition class ID.
CLASS_NAMES = {1: "hammer", 2: "wrench", 3: "pliers"}

# BGR drawing colors per competition class ID (used for annotation).
CLASS_COLORS = {1: (0, 255, 0), 2: (255, 165, 0), 3: (0, 0, 255)}
|
||||
|
||||
|
||||
def find_boxes(image):
    """Find box-like regions in a top-down BGR camera frame.

    Detects bright, low-saturation rectangular regions (light boxes on a
    darker shelf) by combining three threshold masks, cleaning them up
    morphologically, then filtering contours by area and aspect ratio and
    suppressing overlapping detections.

    Args:
        image: BGR image as a numpy array of shape (H, W, 3).

    Returns:
        A list of dicts sorted by contour area (largest first), each with:
        "roi" (cropped BGR patch), "bbox" (x, y, w, h), "center" (cx, cy),
        "area" (contour area), and "contour" (the source contour).

    Side effect: writes the combined binary mask to debug_mask.png in
    BASE_DIR for visual debugging.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = image.shape[:2]

    # The boxes are white/light colored rectangles on a brown/dark shelf.
    # Combine three complementary masks so a pixel must be bright AND
    # unsaturated AND high-value to survive.

    # Approach 1: bright pixels in grayscale.
    _, bright_mask = cv2.threshold(gray, 160, 255, cv2.THRESH_BINARY)

    # Approach 2: low saturation (boxes are less saturated than the shelf).
    _, sat_mask = cv2.threshold(hsv[:, :, 1], 60, 255, cv2.THRESH_BINARY_INV)

    # Approach 3: value channel - boxes are brighter.
    _, val_mask = cv2.threshold(hsv[:, :, 2], 150, 255, cv2.THRESH_BINARY)

    # Combine masks.
    combined = cv2.bitwise_and(bright_mask, sat_mask)
    combined = cv2.bitwise_and(combined, val_mask)

    # Morphological cleanup: close small holes, then remove speckle.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    combined = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, kernel, iterations=3)
    combined = cv2.morphologyEx(combined, cv2.MORPH_OPEN, kernel, iterations=2)

    # Save debug mask for inspection.
    cv2.imwrite(os.path.join(BASE_DIR, "debug_mask.png"), combined)

    contours, _ = cv2.findContours(combined, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    min_area = (h * w) * 0.005  # 0.5% of frame
    max_area = (h * w) * 0.15   # 15% of frame

    boxes = []
    for cnt in contours:
        area = cv2.contourArea(cnt)
        if area < min_area or area > max_area:
            continue

        # NOTE: the original also computed cv2.minAreaRect/boxPoints here,
        # but the result was never used - removed as dead code.
        x, y, bw, bh = cv2.boundingRect(cnt)

        # Aspect ratio filter - boxes should be somewhat rectangular;
        # the epsilon guards against division by zero on degenerate rects.
        aspect = max(bw, bh) / (min(bw, bh) + 1e-5)
        if aspect > 5:
            continue

        # Pad the bounding box slightly, clamped to the frame borders.
        pad = 5
        x1 = max(0, x - pad)
        y1 = max(0, y - pad)
        x2 = min(w, x + bw + pad)
        y2 = min(h, y + bh + pad)

        roi = image[y1:y2, x1:x2]
        if roi.size == 0:
            continue

        boxes.append({
            "roi": roi,
            "bbox": (x1, y1, x2 - x1, y2 - y1),
            "center": ((x1 + x2) // 2, (y1 + y2) // 2),
            "area": area,
            "contour": cnt,
        })

    # Sort by area descending so NMS below keeps the larger of any pair.
    boxes.sort(key=lambda b: b["area"], reverse=True)

    # NMS: remove boxes that overlap significantly with a larger kept box.
    filtered = []
    for box in boxes:
        x1, y1, bw1, bh1 = box["bbox"]
        keep = True
        for kept in filtered:
            x2, y2, bw2, bh2 = kept["bbox"]
            # Compute the intersection rectangle.
            ix1 = max(x1, x2)
            iy1 = max(y1, y2)
            ix2 = min(x1 + bw1, x2 + bw2)
            iy2 = min(y1 + bh1, y2 + bh2)
            if ix2 > ix1 and iy2 > iy1:
                inter = (ix2 - ix1) * (iy2 - iy1)
                area_small = min(bw1 * bh1, bw2 * bh2)
                # If intersection covers >40% of the smaller box, drop it.
                if inter / (area_small + 1e-5) > 0.4:
                    keep = False
                    break
        if keep:
            filtered.append(box)

    return filtered
|
||||
|
||||
|
||||
def classify_roi(model, roi):
    """Classify a single ROI with the YOLO classification model.

    Args:
        model: an ultralytics YOLO classification model.
        roi: BGR image patch to classify.

    Returns:
        Tuple of (competition_class_id, confidence). Returns (-1, 0.0)
        when the model produces no class probabilities.
    """
    prediction = model(roi, imgsz=224, verbose=False)

    # Guard clause: bail out early when there is nothing to read.
    if not prediction or prediction[0].probs is None:
        return -1, 0.0

    top = prediction[0].probs
    comp_class = YOLO_TO_COMPETITION.get(top.top1, -1)
    return comp_class, top.top1conf.item()
|
||||
|
||||
|
||||
def main():
    """Detect boxes in test.png, classify each, and save annotated output.

    Pipeline: load model and image, segment candidate box regions,
    classify every region, draw and save test_result.png, print a
    per-class summary, and fall back to a sliding-window scan when the
    segmentation finds fewer than 3 regions.
    """
    model = YOLO(os.path.join(BASE_DIR, "best.pt"))
    image = cv2.imread(os.path.join(BASE_DIR, "test.png"))
    if image is None:
        print("Error: cannot read test.png")
        return

    print(f"Image size: {image.shape[1]}x{image.shape[0]}")

    # Step 1: locate candidate box regions.
    boxes = find_boxes(image)
    print(f"Found {len(boxes)} box regions")

    # Step 2: classify each region and draw it onto a copy of the frame.
    annotated = image.copy()
    results_list = []
    for idx, det in enumerate(boxes):
        comp_class, conf = classify_roi(model, det["roi"])
        name = CLASS_NAMES.get(comp_class, "unknown")
        results_list.append((comp_class, conf, det["center"], det["bbox"]))

        print(f" Box {idx+1}: {name} (ID={comp_class}, conf={conf:.4f}) "
              f"center=({det['center'][0]}, {det['center'][1]}) "
              f"area={det['area']:.0f}")

        # Draw on annotated image.
        bx, by, bwid, bhgt = det["bbox"]
        color = CLASS_COLORS.get(comp_class, (128, 128, 128))
        cv2.rectangle(annotated, (bx, by), (bx + bwid, by + bhgt), color, 2)
        label = f"{name} {conf:.2f}"
        cv2.putText(annotated, label, (bx, by - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Save annotated result.
    out_path = os.path.join(BASE_DIR, "test_result.png")
    cv2.imwrite(out_path, annotated)
    print(f"\nAnnotated result saved to: {out_path}")

    # Step 3: per-class tally.
    print("\n--- Summary ---")
    class_counts = {}
    for comp_class, _conf, _center, _bbox in results_list:
        tally_name = CLASS_NAMES.get(comp_class, "unknown")
        class_counts[tally_name] = class_counts.get(tally_name, 0) + 1
    for tally_name, count in class_counts.items():
        print(f" {tally_name}: {count}")

    # Step 4: if segmentation found too few boxes, try sliding windows.
    if len(boxes) < 3:
        print("\n--- Fallback: sliding window approach ---")
        sliding_window_detect(model, image)
|
||||
|
||||
|
||||
def sliding_window_detect(model, image):
    """Fallback detector: classify overlapping windows across the frame.

    Scans the image at two window sizes with a 30% step, keeps
    high-confidence (>0.8) classifications, suppresses detections whose
    centers fall close to an already-kept one, then draws and saves the
    annotated result to test_result_sliding.png. No-op if nothing is
    detected.
    """
    h, w = image.shape[:2]
    window_sizes = [(h // 3, w // 4), (h // 2, w // 3)]
    step_ratio = 0.3

    # Collect every confident window classification.
    all_detections = []
    for win_h, win_w in window_sizes:
        dy = int(win_h * step_ratio)
        dx = int(win_w * step_ratio)
        for top in range(0, h - win_h + 1, dy):
            for left in range(0, w - win_w + 1, dx):
                patch = image[top:top + win_h, left:left + win_w]
                comp_class, conf = classify_roi(model, patch)
                if conf > 0.8:
                    cx = left + win_w // 2
                    cy = top + win_h // 2
                    all_detections.append(
                        (comp_class, conf, cx, cy, left, top, win_w, win_h))

    # Nothing confident anywhere: keep the original "do nothing" behavior.
    if not all_detections:
        return

    # Greedy NMS by center distance: highest confidence wins its region.
    all_detections.sort(key=lambda d: d[1], reverse=True)
    kept = []
    for det in all_detections:
        cx, cy = det[2], det[3]
        too_close = any(
            abs(cx - k[2]) < w // 5 and abs(cy - k[3]) < h // 5
            for k in kept
        )
        if not too_close:
            kept.append(det)

    # Draw the surviving detections and report them.
    annotated = image.copy()
    for comp_class, conf, cx, cy, x, y, ww, wh in kept:
        name = CLASS_NAMES.get(comp_class, "unknown")
        color = CLASS_COLORS.get(comp_class, (128, 128, 128))
        cv2.rectangle(annotated, (x, y), (x + ww, y + wh), color, 2)
        label = f"{name} {conf:.2f}"
        cv2.putText(annotated, label, (x, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        print(f" {name} (conf={conf:.2f}) at ({cx}, {cy})")

    out_path = os.path.join(BASE_DIR, "test_result_sliding.png")
    cv2.imwrite(out_path, annotated)
    print(f" Saved to: {out_path}")
|
||||
|
||||
|
||||
# Run the detection pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user