Initial commit: handshapes multiclass project
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
first_attempt_landmark_hands/infer_seq_webcam.py (new executable file, 198 lines)
@@ -0,0 +1,198 @@
#!/usr/bin/env python3
"""
infer_seq_webcam.py

Live webcam demo: detect a hand with MediaPipe, normalize landmarks,
classify with a trained sequence GRU model (multiclass).

Examples:
    python infer_seq_webcam.py --model asl_seq32_gru_ABJZ.pt --threshold 0.8 --smooth 0.7
    python infer_seq_webcam.py --model asl_seq32_gru_ABJZ.pt --threshold 0.85 --smooth 1.0 --url https://www.google.com
"""

import os, math, argparse, time, webbrowser

# --- Quiet logs (env vars must be set before mediapipe/TF are imported) ---
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["GLOG_minloglevel"] = "2"

import numpy as np
import cv2
import torch
import mediapipe as mp

import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)
cv2.setLogLevel(0)

# ---------- geometry helpers ----------
def _angle(v):
    return math.atan2(v[1], v[0])

def _rot2d(t):
    c, s = math.cos(t), math.sin(t)
    return np.array([[c, -s], [s, c]], dtype=np.float32)
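
# Note: _rot2d(t) is the standard counter-clockwise rotation matrix
# [[cos t, -sin t], [sin t, cos t]]. normalize_landmarks() below picks
# t = pi/2 - _angle(v) so that the wrist->middle-MCP vector v lands on +Y.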

def normalize_landmarks(pts, handedness_label=None):
    """
    pts: (21, 3) MediaPipe normalized coords in [0..1]
    Steps: translate wrist->origin, mirror left to right, rotate to +Y,
    scale by max pairwise distance.
    Returns: (63,) float32
    """
    pts = pts.astype(np.float32).copy()
    pts[:, :2] -= pts[0, :2]  # wrist to origin
    if handedness_label and handedness_label.lower().startswith("left"):
        pts[:, 0] *= -1.0
    v = pts[9, :2]  # middle MCP
    R = _rot2d(math.pi / 2 - _angle(v))
    pts[:, :2] = pts[:, :2] @ R.T
    xy = pts[:, :2]
    d = np.linalg.norm(xy[None, :, :] - xy[:, None, :], axis=-1).max()
    d = 1.0 if d < 1e-6 else float(d)
    pts[:, :2] /= d
    pts[:, 2] /= d
    return pts.reshape(-1)
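
# A quick REPL sanity check with hypothetical data (not part of the demo):
# the feature vector should be invariant to uniform scaling and in-plane shifts.
#   pts = np.random.rand(21, 3).astype(np.float32)
#   pts2 = pts * 0.5                  # uniform rescale
#   pts2[:, :2] += 0.1                # shift in the image plane
#   np.allclose(normalize_landmarks(pts), normalize_landmarks(pts2), atol=1e-5)  # -> True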

# ---------- sequence model ----------
class SeqGRU(torch.nn.Module):
    def __init__(self, input_dim=63, hidden=128, num_classes=26):
        super().__init__()
        self.gru = torch.nn.GRU(input_dim, hidden, batch_first=True, bidirectional=True)
        self.head = torch.nn.Sequential(
            torch.nn.Linear(hidden * 2, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, num_classes),
        )

    def forward(self, x):
        h, _ = self.gru(x)    # (B, T, 2H)
        h_last = h[:, -1, :]  # or h.mean(1)
        return self.head(h_last)
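
# Shape walk-through (illustrative, using this file's defaults T=32, input_dim=63):
# x (B, 32, 63) -> bidirectional GRU h (B, 32, 256) -> last step (B, 256)
# -> Linear/ReLU/Dropout (B, 128) -> logits (B, num_classes).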

# ---------- main ----------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True, help="Path to trained .pt model")
    ap.add_argument("--camera", type=int, default=0)
    ap.add_argument("--threshold", type=float, default=0.8)
    ap.add_argument("--smooth", type=float, default=0.7,
                    help="EMA time constant in seconds (0 disables smoothing)")
    ap.add_argument("--width", type=int, default=640)
    ap.add_argument("--height", type=int, default=480)
    ap.add_argument("--url", type=str, default="https://www.google.com",
                    help="URL to open when the sequence W→E→B is detected")
    args = ap.parse_args()

    if not os.path.exists(args.model):
        raise SystemExit(f"❌ Model file not found: {args.model}")

    # Load checkpoint (support numpy or tensor stats; support 'frames' if present)
    state = torch.load(args.model, map_location="cpu", weights_only=False)
    classes = state["classes"]
    T = int(state.get("frames", 32))

    X_mean, X_std = state["X_mean"], state["X_std"]
    if isinstance(X_mean, torch.Tensor):
        X_mean = X_mean.cpu().numpy()
    if isinstance(X_std, torch.Tensor):
        X_std = X_std.cpu().numpy()
    X_mean = X_mean.astype(np.float32)
    X_std = X_std.astype(np.float32) + 1e-6  # avoid division by zero

    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model = SeqGRU(63, 128, num_classes=len(classes)).to(device)
    model.load_state_dict(state["model"])
    model.eval()
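
    # Checkpoint layout this loader expects (inferred from the keys used above):
    #   {"model": state_dict, "classes": [...labels...],
    #    "frames": e.g. 32, "X_mean": (63,) stats, "X_std": (63,) stats}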

    hands = mp.solutions.hands.Hands(
        static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
    )

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise SystemExit(f"❌ Could not open camera index {args.camera}")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height)

    print(f"✅ Loaded {args.model} frames={T} classes={classes}")
    print("Press 'q' to quit.")

    seq_buffer, ema_probs = [], None
    last_ts = time.time()
    last_emitted_letter = None

    # Rolling history of emitted letters to detect the sequence "WEB"
    detected_history = []  # only stores emitted letters (deduped by change)

    while True:
        ok, frame = cap.read()
        if not ok:
            break
        now = time.time()
        dt = max(1e-6, now - last_ts)
        last_ts = now

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        res = hands.process(rgb)

        overlay_text = "No hand"
        current_letter = None

        if res.multi_hand_landmarks:
            ih = res.multi_hand_landmarks[0]
            handed = None
            if res.multi_handedness:
                handed = res.multi_handedness[0].classification[0].label
            pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32)
            feat = normalize_landmarks(pts, handedness_label=handed)
            seq_buffer.append(feat)
            if len(seq_buffer) > T:
                seq_buffer.pop(0)

            if len(seq_buffer) == T:
                X = np.stack(seq_buffer, 0)
                Xn = (X - X_mean) / X_std
                xt = torch.from_numpy(Xn).float().unsqueeze(0).to(device)
                with torch.no_grad():
                    logits = model(xt)
                    probs = torch.softmax(logits, dim=1)[0].cpu().numpy()

                if args.smooth > 0:
                    alpha = 1.0 - math.exp(-dt / args.smooth)
                    if ema_probs is None:
                        ema_probs = probs
                    else:
                        ema_probs = (1.0 - alpha) * ema_probs + alpha * probs
                    use_probs = ema_probs
                else:
                    use_probs = probs
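
                # Note: alpha = 1 - exp(-dt / tau) makes the EMA frame-rate
                # independent: the same --smooth value decays old probabilities
                # at the same rate per second regardless of camera FPS.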

                top_idx = int(np.argmax(use_probs))
                top_p = float(use_probs[top_idx])
                top_cls = classes[top_idx]

                if top_p >= args.threshold:
                    overlay_text = f"{top_cls} {top_p*100:.1f}%"
                    current_letter = top_cls
        else:
            seq_buffer, ema_probs = [], None  # hand lost: reset sequence and smoothing

        # Only emit when a *letter* changes (ignore no-hand and repeats)
        if current_letter is not None and current_letter != last_emitted_letter:
            print(f"Detected: {current_letter}")
            last_emitted_letter = current_letter

            # Update rolling history
            detected_history.append(current_letter)
            if len(detected_history) > 3:
                detected_history.pop(0)

            # Check for special sequence "WEB"
            if detected_history == ["W", "E", "B"]:
                print("🚀 Detected WEB! Time to open the web browser app.")
                try:
                    webbrowser.open(args.url)
                except Exception as e:
                    print(f"⚠️ Failed to open browser: {e}")
                detected_history.clear()  # fire once per occurrence
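
            # Because the history only records letter *changes*, W -> E -> B must
            # be signed as three consecutive distinct detections; clearing the
            # list means the full sequence must be re-signed to fire again.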

        # On-screen overlay (still shows "No hand" when nothing is detected)
        cv2.putText(frame, overlay_text, (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
        cv2.imshow("ASL sequence demo", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == "__main__":
    main()