Initial commit: ASL handshape recognition project
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
infer_webcam.py | 137 lines (new executable file)
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
infer_webcam.py

Live webcam demo: detect a hand with MediaPipe, normalize landmarks,
and classify with a trained MLP model.

Examples:
    python infer_webcam.py --letter A    # loads asl_A_mlp.pt
    python infer_webcam.py --letter B    # loads asl_B_mlp.pt
    python infer_webcam.py --model /path/to/asl_A_mlp.pt

Press 'q' to quit.
"""
import os
import math
import argparse

import numpy as np
import cv2
import torch
import mediapipe as mp


# ---------- geometry helpers ----------
def _angle(v):
    """Angle of a 2-D vector, in radians."""
    return math.atan2(v[1], v[0])


def _rot2d(t):
    """2-D rotation matrix for angle t (radians)."""
    c, s = math.cos(t), math.sin(t)
    return np.array([[c, -s], [s, c]], dtype=np.float32)


def normalize_landmarks(pts, handedness_label=None):
    """Flatten 21x3 landmarks into a 63-dim feature vector that is invariant
    to hand position, handedness, in-plane rotation, and scale."""
    pts = pts.astype(np.float32).copy()
    # translate wrist (landmark 0) to origin
    pts[:, :2] -= pts[0, :2]
    # mirror left→right so both hands share one canonical frame
    if handedness_label and handedness_label.lower().startswith("left"):
        pts[:, 0] *= -1.0
    # rotate wrist→middle_mcp (landmark 9) to +Y
    v = pts[9, :2]
    R = _rot2d(math.pi / 2 - _angle(v))
    pts[:, :2] = pts[:, :2] @ R.T
    # scale by max pairwise distance
    xy = pts[:, :2]
    d = np.linalg.norm(xy[None, :, :] - xy[:, None, :], axis=-1).max()
    d = 1.0 if d < 1e-6 else float(d)
    pts[:, :2] /= d
    pts[:, 2] /= d
    return pts.reshape(-1)
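
# Illustrative sanity check for normalize_landmarks (not executed by the
# demo): after normalization the wrist sits at the origin and the
# wrist→middle_mcp axis points along +Y, e.g.
#
#     pts = np.random.rand(21, 3).astype(np.float32)
#     feat = normalize_landmarks(pts).reshape(21, 3)
#     assert np.allclose(feat[0, :2], 0.0)   # wrist at origin
#     assert abs(feat[9, 0]) < 1e-5          # landmark 9 on the +Y axis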

# ---------- model ----------
class MLP(torch.nn.Module):
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.net(x)
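
# Note: these layer sizes must match the checkpoint exactly; load_state_dict()
# in main() will raise if the saved tensor shapes differ.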

# ---------- main ----------
def main():
    ap = argparse.ArgumentParser()
    grp = ap.add_mutually_exclusive_group(required=True)
    grp.add_argument("--letter", help="Target letter (A–Z). Loads asl_<LETTER>_mlp.pt")
    grp.add_argument("--model", help="Path to trained .pt model (alternative to --letter)")
    ap.add_argument("--camera", type=int, default=0, help="OpenCV camera index (default: 0)")
    args = ap.parse_args()

    # Resolve model path
    model_path = args.model
    if model_path is None:
        letter = args.letter.upper()
        model_path = f"asl_{letter}_mlp.pt"

    if not os.path.exists(model_path):
        raise SystemExit(f"❌ Model file not found: {model_path}")

    # Load the checkpoint. weights_only=False is needed because the file
    # stores numpy arrays alongside tensors, but it runs arbitrary pickle
    # code: only load checkpoints you trust.
    state = torch.load(model_path, map_location="cpu", weights_only=False)
    classes = state["classes"]
    X_mean = state["X_mean"]
    X_std = state["X_std"]

    # Convert X_mean/X_std to numpy no matter how they were saved
    if isinstance(X_mean, torch.Tensor):
        X_mean = X_mean.cpu().numpy()
    if isinstance(X_std, torch.Tensor):
        X_std = X_std.cpu().numpy()
    X_mean = np.asarray(X_mean, dtype=np.float32)
    X_std = np.asarray(X_std, dtype=np.float32) + 1e-6  # guard against zero std
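
    # X_mean/X_std are the training-set feature statistics; reusing them here
    # keeps inference inputs on the same scale the MLP saw during training.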

    model = MLP(63, len(classes))  # 63 = 21 landmarks x 3 coords
    model.load_state_dict(state["model"])
    model.eval()

    # Prefer Apple's MPS backend when available, otherwise fall back to CPU
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model.to(device)

    hands = mp.solutions.hands.Hands(
        static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
    )

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise SystemExit(f"❌ Could not open camera index {args.camera}")

    print(f"✅ Loaded {model_path} with classes {classes}")
    print("Press 'q' to quit.")
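
    # Per-frame pipeline: capture -> BGR-to-RGB -> MediaPipe landmarks ->
    # normalize -> standardize -> MLP logits -> softmax -> on-screen label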
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB input
        res = hands.process(rgb)

        label_text = "No hand"
        if res.multi_hand_landmarks:
            ih = res.multi_hand_landmarks[0]
            handed = None
            if res.multi_handedness:
                handed = res.multi_handedness[0].classification[0].label
            pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32)
            feat = normalize_landmarks(pts, handedness_label=handed)
            # standardize with the training-set statistics
            xn = (feat - X_mean.flatten()) / X_std.flatten()
            xt = torch.from_numpy(xn).float().unsqueeze(0).to(device)
            with torch.no_grad():
                probs = torch.softmax(model(xt), dim=1)[0].cpu().numpy()
            idx = int(probs.argmax())
            label_text = f"{classes[idx]} {probs[idx]*100:.1f}%"

        cv2.putText(frame, label_text, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0, 255, 0), 2)
        cv2.imshow("ASL handshape demo", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    hands.close()
    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()