#!/usr/bin/env python3
"""
infer_webcam_multi.py

Live multi-letter inference from webcam using multiple per-letter binary models.

Examples:
    # Detect A, B, C using default filenames asl_A_mlp.pt, asl_B_mlp.pt, asl_C_mlp.pt
    python infer_webcam_multi.py --letters A,B,C

    # Same but with a confidence threshold for accepting any letter
    python infer_webcam_multi.py --letters A,B,C --threshold 0.8

    # Explicit model paths (overrides --letters)
    python infer_webcam_multi.py --models asl_A_mlp.pt asl_B_mlp.pt --threshold 0.75

Press 'q' to quit.
"""
import argparse
import math
import os

import cv2
import mediapipe as mp
import numpy as np
import torch
# ---------- geometry helpers ----------
def _angle(v): return math.atan2(v[1], v[0])
|
|
def _rot2d(t):
|
|
c, s = math.cos(t), math.sin(t)
|
|
return np.array([[c, -s], [s, c]], dtype=np.float32)
|
|
|
|
def normalize_landmarks(pts, handedness_label=None):
    """Canonicalize a 21x3 hand-landmark array and flatten it to 63 values.

    Pipeline: translate the wrist (landmark 0) to the x-y origin, mirror
    left hands onto the right, rotate so the wrist->middle-MCP direction
    (landmark 9) points along +Y, then scale x, y and z by the largest
    pairwise x-y distance.
    """
    out = pts.astype(np.float32).copy()

    # Put the wrist at the x-y origin.
    out[:, :2] -= out[0, :2]

    # Mirror left hands so both chiralities share one representation.
    if handedness_label and handedness_label.lower().startswith("left"):
        out[:, 0] = -out[:, 0]

    # Rotate so the wrist->middle-MCP direction lies on the +Y axis.
    mcp_x, mcp_y = float(out[9, 0]), float(out[9, 1])
    theta = math.pi / 2 - math.atan2(mcp_y, mcp_x)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    rot = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=np.float32)
    out[:, :2] = out[:, :2] @ rot.T

    # Scale so the largest pairwise x-y distance is 1 (guard tiny/degenerate hands).
    xy = out[:, :2]
    max_dist = float(np.linalg.norm(xy[None, :, :] - xy[:, None, :], axis=-1).max())
    if max_dist < 1e-6:
        max_dist = 1.0
    out[:, :2] /= max_dist
    out[:, 2] /= max_dist

    return out.reshape(-1)
# ---------- MLP ----------
class MLP(torch.nn.Module):
    """Small fully-connected classifier mapping feature vectors to class logits.

    Layer order (and hence state-dict keys net.0 ... net.6) matches the
    checkpoints produced at training time: Linear-ReLU-Dropout twice,
    then a final Linear head.
    """

    def __init__(self, in_dim, num_classes):
        super().__init__()
        layers = [
            torch.nn.Linear(in_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_classes),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unsoftmaxed) logits for a batch of inputs."""
        return self.net(x)
# ---------- Utilities ----------
def load_model_bundle(model_path):
    """Load one per-letter checkpoint and return a ready-to-use bundle dict.

    Returns a dict with:
      - 'path':      the checkpoint path that was loaded
      - 'model':     torch.nn.Module in eval mode (on CPU; caller moves it)
      - 'classes':   list of class names, e.g. ['Not_A', 'A']
      - 'pos_index': index of the positive (letter) class in 'classes'
      - 'X_mean', 'X_std': float32 numpy arrays for feature standardization
      - 'letter':    display name of the letter, e.g. 'A'
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    state = torch.load(model_path, map_location="cpu", weights_only=False)
    classes = state["classes"]

    # Positive class: first name that does not start with "Not_";
    # fall back to the last class when every name carries the prefix.
    pos_idx = next(
        (i for i, name in enumerate(classes) if not name.lower().startswith("not_")),
        len(classes) - 1,
    )

    # Display letter: strip a leading "Not_" if the fallback was taken.
    letter = classes[pos_idx]
    if letter.lower().startswith("not_"):
        letter = letter[4:]

    # Standardization stats -> float32 numpy; epsilon keeps division safe.
    def _as_numpy(a):
        return a.cpu().numpy() if isinstance(a, torch.Tensor) else a

    X_mean = np.asarray(_as_numpy(state["X_mean"]), dtype=np.float32)
    X_std = np.asarray(_as_numpy(state["X_std"]), dtype=np.float32) + 1e-6

    model = MLP(63, len(classes))
    model.load_state_dict(state["model"])
    model.eval()

    return {
        "path": model_path,
        "model": model,
        "classes": classes,
        "pos_index": pos_idx,
        "X_mean": X_mean,
        "X_std": X_std,
        "letter": letter,
    }
def put_text(img, text, org, scale=1.1, color=(0, 255, 0), thick=2):
    """Draw anti-aliased text onto *img* in-place at pixel position *org*."""
    cv2.putText(
        img,
        text,
        org,
        cv2.FONT_HERSHEY_SIMPLEX,
        scale,
        color,
        thick,
        cv2.LINE_AA,
    )
# ---------- Main ----------
def main():
    """Run live webcam inference, scoring every loaded per-letter model per frame.

    Fixes over the original: device selection now also considers CUDA,
    camera/window/MediaPipe resources are released via try/finally even on
    error, and the MediaPipe Hands object is explicitly closed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--letters", help="Comma-separated letters, e.g. A,B,C (uses asl_<L>_mlp.pt)")
    ap.add_argument("--models", nargs="+", help="Explicit model paths (overrides --letters)")
    ap.add_argument("--threshold", type=float, default=0.5,
                    help="Reject threshold on positive-class probability (default: 0.5)")
    ap.add_argument("--camera", type=int, default=0, help="OpenCV camera index (default: 0)")
    ap.add_argument("--width", type=int, default=640, help="Requested capture width (default: 640)")
    ap.add_argument("--height", type=int, default=480, help="Requested capture height (default: 480)")
    args = ap.parse_args()

    # Resolve model paths: explicit --models wins, otherwise derive from --letters.
    model_paths = []
    if args.models:
        model_paths = args.models
    elif args.letters:
        for L in [s.strip().upper() for s in args.letters.split(",") if s.strip()]:
            model_paths.append(f"asl_{L}_mlp.pt")
    else:
        raise SystemExit("Please provide --letters A,B,C or --models path1.pt path2.pt ...")

    # Fail fast on missing files before loading anything.
    for p in model_paths:
        if not os.path.exists(p):
            raise SystemExit(f"❌ Model file not found: {p}")

    # Device: prefer CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # Load all per-letter bundles and move their models to the device.
    bundles = [load_model_bundle(p) for p in model_paths]
    for b in bundles:
        b["model"].to(device)
    print("✅ Loaded models:", ", ".join(f"{b['letter']}({os.path.basename(b['path'])})" for b in bundles))

    # MediaPipe Hands (video mode, single hand).
    hands = mp.solutions.hands.Hands(
        static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
    )

    # Camera
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        hands.close()
        raise SystemExit(f"❌ Could not open camera index {args.camera}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)

    print("Press 'q' to quit.")
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            # MediaPipe expects RGB; OpenCV captures BGR.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            res = hands.process(rgb)

            overlay = frame.copy()
            label_text = "No hand"
            scoreboard = []
            if res.multi_hand_landmarks:
                ih = res.multi_hand_landmarks[0]
                handed = None
                if res.multi_handedness:
                    handed = res.multi_handedness[0].classification[0].label

                pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32)
                feat = normalize_landmarks(pts, handedness_label=handed)

                # Score the frame against every per-letter model; keep the best.
                best_letter, best_prob = None, -1.0
                for b in bundles:
                    xn = (feat - b["X_mean"].flatten()) / b["X_std"].flatten()
                    xt = torch.from_numpy(xn).float().unsqueeze(0).to(device)
                    with torch.no_grad():
                        probs = torch.softmax(b["model"](xt), dim=1)[0].cpu().numpy()
                    p_pos = float(probs[b["pos_index"]])
                    scoreboard.append((b["letter"], p_pos))
                    if p_pos > best_prob:
                        best_prob = p_pos
                        best_letter = b["letter"]

                # Accept the winner only above the confidence threshold.
                if best_prob >= args.threshold:
                    label_text = f"{best_letter} {best_prob*100:.1f}%"
                else:
                    label_text = f"Unknown ({best_letter} {best_prob*100:.1f}%)"

            # Show the top-3 per-letter scores.
            scoreboard.sort(key=lambda x: x[1], reverse=True)
            y0 = 80
            put_text(overlay, "Scores:", (20, y0), scale=0.9, color=(0, 255, 255), thick=2)
            y = y0 + 30
            for L, p in scoreboard[:3]:
                put_text(overlay, f"{L}: {p*100:.1f}%", (20, y), scale=0.9, color=(0, 255, 0), thick=2)
                y += 28

            put_text(overlay, label_text, (20, 40), scale=1.2, color=(0, 255, 0), thick=3)
            cv2.imshow("ASL multi-letter demo", overlay)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release camera, windows and MediaPipe resources even on error/Ctrl-C.
        cap.release()
        cv2.destroyAllWindows()
        hands.close()
# Run the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()