Files
slr_handshapes/infer_webcam-multi.py
2026-01-19 22:19:15 -05:00

221 lines
7.9 KiB
Python

#!/usr/bin/env python3
"""
infer_webcam_multi.py
Live multi-letter inference from webcam using multiple per-letter binary models.
Examples:
# Detect A, B, C using default filenames asl_A_mlp.pt, asl_B_mlp.pt, asl_C_mlp.pt
python infer_webcam_multi.py --letters A,B,C
# Same but with a confidence threshold for accepting any letter
python infer_webcam_multi.py --letters A,B,C --threshold 0.8
# Explicit model paths (overrides --letters)
python infer_webcam_multi.py --models asl_A_mlp.pt asl_B_mlp.pt --threshold 0.75
Press 'q' to quit.
"""
import os, math, argparse
import numpy as np
import cv2
import torch
import mediapipe as mp
# ---------- geometry helpers ----------
def _angle(v):
    """Return the polar angle of the 2-D vector ``v`` in radians (range [-pi, pi])."""
    x, y = v[0], v[1]
    return math.atan2(y, x)
def _rot2d(t):
c, s = math.cos(t), math.sin(t)
return np.array([[c, -s], [s, c]], dtype=np.float32)
def normalize_landmarks(pts, handedness_label=None):
    """
    Canonicalize one hand's landmarks and flatten them into a feature vector.

    Steps, applied in this order:
      1. translate x/y so landmark 0 (the wrist) sits at the origin;
      2. if ``handedness_label`` starts with "left" (case-insensitive), negate x
         so a left hand is mirrored onto a right hand;
      3. rotate x/y so the vector from landmark 0 to landmark 9 points along +Y;
      4. divide x, y and z by the largest pairwise x/y distance between
         landmarks (1.0 is used instead when that distance is below 1e-6).

    Args:
        pts: numpy array with one row per landmark and x, y, z in columns 0-2
            (main builds it from MediaPipe hand landmarks; presumably (21, 3) so
            the result matches the 63-input MLP — confirm).
        handedness_label: optional label such as "Left"/"Right"; None or ""
            disables the mirroring step.

    Returns:
        1-D float32 array containing the normalized landmarks, row by row.
    """
    hand = pts.astype(np.float32)  # astype returns a new array; input is untouched

    # 1. Wrist-relative x/y.
    wrist_xy = hand[0, :2].copy()
    hand[:, :2] -= wrist_xy

    # 2. Mirror left hands so both hands share one canonical orientation.
    is_left = bool(handedness_label) and handedness_label.lower().startswith("left")
    if is_left:
        hand[:, 0] = -hand[:, 0]

    # 3. Rotate so wrist -> landmark 9 points straight up (+Y).
    theta = math.pi / 2 - _angle(hand[9, :2])
    hand[:, :2] = hand[:, :2] @ _rot2d(theta).T

    # 4. Scale-normalize by the hand's largest x/y extent.
    planar = hand[:, :2]
    diffs = planar[np.newaxis, :, :] - planar[:, np.newaxis, :]
    span = np.linalg.norm(diffs, axis=-1).max()
    if span < 1e-6:
        span = 1.0  # degenerate hand (points coincide): skip scaling
    else:
        span = float(span)
    hand[:, :2] /= span
    hand[:, 2] /= span

    return hand.ravel()
# ---------- MLP ----------
class MLP(torch.nn.Module):
    """
    Small fully connected classifier: in_dim -> 128 -> 64 -> num_classes.

    ReLU follows each hidden layer, with dropout of 0.2 after the first and
    0.1 after the second. The layers live in ``self.net`` (a Sequential), so
    checkpoint keys are ``net.0.*``, ``net.3.*`` and ``net.6.*``; keep this
    layout or previously saved state dicts will no longer load.
    """

    def __init__(self, in_dim, num_classes):
        super().__init__()
        nn = torch.nn
        layers = [
            nn.Linear(in_dim, 128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(64, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits (no softmax) for a batch of feature rows ``x``."""
        logits = self.net(x)
        return logits
# ---------- Utilities ----------
def load_model_bundle(model_path):
    """
    Load one per-letter binary checkpoint and package it for inference.

    The checkpoint (read with ``torch.load``) must contain the keys
    ``"classes"``, ``"model"`` (a state dict for ``MLP``), ``"X_mean"`` and
    ``"X_std"``.

    Args:
        model_path: path to a ``.pt`` checkpoint file.

    Returns:
        dict with:
          - 'path': ``model_path`` as given
          - 'model': ``MLP(63, len(classes))`` with weights loaded, in eval
            mode, still on the CPU (the caller is responsible for ``.to(device)``)
          - 'classes': the checkpoint's class names, e.g. ['Not_A', 'A']
          - 'pos_index': index of the positive (letter) class in 'classes'
          - 'X_mean', 'X_std': float32 numpy arrays (presumably shape (1, 63));
            1e-6 is added to 'X_std' so it is safe to divide by
          - 'letter': display name of the positive class (e.g. 'A')

    Raises:
        KeyError: if the checkpoint lacks one of the required keys.
        ValueError: if the checkpoint's 'classes' list is empty.
    """
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Only load checkpoints from trusted sources.
    # (The stats may be stored as NumPy arrays, which the weights_only=True
    # loader rejects unless they are allow-listed.)
    state = torch.load(model_path, map_location="cpu", weights_only=False)

    required = ("classes", "model", "X_mean", "X_std")
    missing = [k for k in required if k not in state]
    if missing:
        raise KeyError(f"{model_path}: checkpoint is missing key(s) {missing}")

    classes = state["classes"]
    if len(classes) == 0:
        raise ValueError(f"{model_path}: checkpoint has an empty 'classes' list")

    # Positive class = first name not starting with "Not_" (case-insensitive);
    # fall back to the last class if every name does.
    pos_idx = next(
        (i for i, c in enumerate(classes) if not c.lower().startswith("not_")),
        len(classes) - 1,
    )
    letter_name = classes[pos_idx]
    if letter_name.lower().startswith("not_"):
        letter_name = letter_name[4:]

    def _as_float32(a):
        # Stats may be saved as tensors or array-likes; detach() lets tensors
        # that still require grad convert to NumPy.
        if isinstance(a, torch.Tensor):
            a = a.detach().cpu().numpy()
        return np.asarray(a, dtype=np.float32)

    X_mean = _as_float32(state["X_mean"])
    X_std = _as_float32(state["X_std"]) + 1e-6  # callers divide by X_std

    model = MLP(63, len(classes))
    model.load_state_dict(state["model"])
    model.eval()

    return {
        "path": model_path,
        "model": model,
        "classes": classes,
        "pos_index": pos_idx,
        "X_mean": X_mean,
        "X_std": X_std,
        "letter": letter_name,
    }
def put_text(img, text, org, scale=1.1, color=(0,255,0), thick=2):
    """
    Draw anti-aliased HERSHEY_SIMPLEX ``text`` onto ``img`` in place.

    ``org`` is the (x, y) bottom-left corner of the text, ``scale`` is the
    font scale, ``color`` is the BGR colour and ``thick`` is the stroke
    thickness.
    """
    cv2.putText(
        img,
        text,
        org,
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=scale,
        color=color,
        thickness=thick,
        lineType=cv2.LINE_AA,
    )
# ---------- Main ----------
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    ap = argparse.ArgumentParser(
        description="Live multi-letter ASL inference from a webcam using per-letter binary models."
    )
    ap.add_argument("--letters", help="Comma-separated letters, e.g. A,B,C (uses asl_<L>_mlp.pt)")
    ap.add_argument("--models", nargs="+", help="Explicit model paths (overrides --letters)")
    ap.add_argument("--threshold", type=float, default=0.5,
                    help="Reject threshold on positive-class probability (default: 0.5)")
    ap.add_argument("--camera", type=int, default=0, help="OpenCV camera index (default: 0)")
    ap.add_argument("--width", type=int, default=640, help="Requested capture width (default: 640)")
    ap.add_argument("--height", type=int, default=480, help="Requested capture height (default: 480)")
    return ap.parse_args()


def _resolve_model_paths(args):
    """Return the checkpoint paths to load, exiting with a message if none are usable."""
    if args.models:
        model_paths = list(args.models)
    elif args.letters:
        letters = [s.strip().upper() for s in args.letters.split(",") if s.strip()]
        model_paths = [f"asl_{letter}_mlp.pt" for letter in letters]
    else:
        raise SystemExit("Please provide --letters A,B,C or --models path1.pt path2.pt ...")

    # e.g. --letters "," parses to nothing; without models there is nothing to score.
    if not model_paths:
        raise SystemExit("❌ --letters did not contain any letters (e.g. --letters A,B,C)")

    missing = [p for p in model_paths if not os.path.isfile(p)]
    if missing:
        raise SystemExit(f"❌ Model file(s) not found: {', '.join(missing)}")
    return model_paths


def _select_device():
    """Pick the inference device: CUDA if available, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _score_hand(feat, scorers, device):
    """Return [(letter, positive-class probability), ...] for one feature vector, in model order."""
    scores = []
    with torch.no_grad():
        for letter, pos_index, model, mean, std in scorers:
            xn = (feat - mean) / std
            xt = torch.from_numpy(xn).float().unsqueeze(0).to(device)
            probs = torch.softmax(model(xt), dim=1)[0]
            scores.append((letter, float(probs[pos_index])))
    return scores


def _draw_overlay(frame, label_text, scoreboard):
    """Draw the main label and, when available, the top-3 scores onto ``frame`` in place."""
    if scoreboard:
        ranked = sorted(scoreboard, key=lambda s: s[1], reverse=True)
        y0 = 80
        put_text(frame, "Scores:", (20, y0), scale=0.9, color=(0,255,255), thick=2)
        y = y0 + 30
        for letter, p in ranked[:3]:
            put_text(frame, f"{letter}: {p*100:.1f}%", (20, y), scale=0.9, color=(0,255,0), thick=2)
            y += 28
    put_text(frame, label_text, (20, 40), scale=1.2, color=(0,255,0), thick=3)


def main():
    """
    Run the live demo: parse CLI args, load the per-letter models, then loop
    over webcam frames until 'q' is pressed or the camera stops delivering frames.

    For each frame, the first detected hand is normalized with
    ``normalize_landmarks`` and scored by every model. The overlay shows the
    best letter (or "Unknown (...)" when its probability is below --threshold)
    plus the top-3 scores. The camera, MediaPipe graph and OpenCV windows are
    released even if the loop raises (including Ctrl-C).
    """
    args = _parse_args()
    model_paths = _resolve_model_paths(args)
    device = _select_device()

    bundles = [load_model_bundle(p) for p in model_paths]
    for b in bundles:
        b["model"].to(device)
    print("✅ Loaded models:", ", ".join(f"{b['letter']}({os.path.basename(b['path'])})" for b in bundles))

    # Per-model inputs, prepared once instead of on every frame.
    scorers = [
        (b["letter"], b["pos_index"], b["model"], b["X_mean"].reshape(-1), b["X_std"].reshape(-1))
        for b in bundles
    ]

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        cap.release()
        raise SystemExit(f"❌ Could not open camera index {args.camera}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)
    print("Press 'q' to quit.")

    try:
        with mp.solutions.hands.Hands(
            static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
        ) as hands:
            while True:
                ok, frame = cap.read()
                if not ok:
                    print("Camera stopped returning frames; exiting.")
                    break

                res = hands.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                label_text = "No hand"
                scoreboard = []

                if res.multi_hand_landmarks:
                    hand_lms = res.multi_hand_landmarks[0]
                    handed = None
                    if res.multi_handedness:
                        # NOTE(review): MediaPipe documents handedness as assuming a
                        # mirrored (selfie) image; this frame is not flipped, so the
                        # label may be swapped. Kept as-is on the assumption that the
                        # training data was produced the same way — confirm.
                        handed = res.multi_handedness[0].classification[0].label
                    pts = np.array([[lm.x, lm.y, lm.z] for lm in hand_lms.landmark], dtype=np.float32)
                    feat = normalize_landmarks(pts, handedness_label=handed)

                    scoreboard = _score_hand(feat, scorers, device)
                    # max() keeps the first model on ties, like a strict ">" scan.
                    best_letter, best_prob = max(scoreboard, key=lambda s: s[1])
                    if best_prob >= args.threshold:
                        label_text = f"{best_letter} {best_prob*100:.1f}%"
                    else:
                        label_text = f"Unknown ({best_letter} {best_prob*100:.1f}%)"

                _draw_overlay(frame, label_text, scoreboard)
                cv2.imshow("ASL multi-letter demo", frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        cap.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
main()