#!/usr/bin/env python3
"""
infer_webcam.py

Live webcam demo: detect a hand with MediaPipe, normalize the landmarks,
and classify the handshape with a trained MLP model.

Examples:
    python infer_webcam.py --letter A              # loads asl_A_mlp.pt
    python infer_webcam.py --letter B              # loads asl_B_mlp.pt
    python infer_webcam.py --model /path/to/asl_A_mlp.pt

Press 'q' to quit.
"""
import os
import math
import argparse

import numpy as np
import cv2
import torch
import mediapipe as mp

# ---------- geometry helpers ----------
def _angle(v):
    """Angle of a 2-D vector, in radians."""
    return math.atan2(v[1], v[0])

def _rot2d(t):
    """2-D rotation matrix for angle t (radians)."""
    c, s = math.cos(t), math.sin(t)
    return np.array([[c, -s], [s, c]], dtype=np.float32)

def normalize_landmarks(pts, handedness_label=None):
    """Canonicalize a (21, 3) array of hand landmarks into a flat 63-dim feature vector."""
    pts = pts.astype(np.float32).copy()
    # translate so the wrist (landmark 0) sits at the origin
    pts[:, :2] -= pts[0, :2]
    # mirror left hands to right so the classifier only sees one chirality
    if handedness_label and handedness_label.lower().startswith("left"):
        pts[:, 0] *= -1.0
    # rotate so the wrist -> middle-finger MCP (landmark 9) vector points along +Y
    v = pts[9, :2]
    R = _rot2d(math.pi / 2 - _angle(v))
    pts[:, :2] = pts[:, :2] @ R.T
    # scale by the maximum pairwise distance between landmarks
    xy = pts[:, :2]
    d = np.linalg.norm(xy[None, :, :] - xy[:, None, :], axis=-1).max()
    d = 1.0 if d < 1e-6 else float(d)  # guard against degenerate detections
    pts[:, :2] /= d
    pts[:, 2] /= d
    return pts.reshape(-1)

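# Quick shape check (hypothetical values, only to illustrate the layout):
#   pts = np.zeros((21, 3), dtype=np.float32)  # 21 MediaPipe hand landmarks
#   feat = normalize_landmarks(pts)
#   feat.shape == (63,)                        # 21 landmarks x 3 coords
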
# ---------- model ----------
class MLP(torch.nn.Module):
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.net(x)

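# Example (hypothetical): a binary "A vs. not-A" head over the 63-dim features:
#   model = MLP(in_dim=63, num_classes=2)
#   logits = model(torch.randn(1, 63))  # logits has shape (1, 2)
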
# ---------- main ----------
def main():
    ap = argparse.ArgumentParser()
    grp = ap.add_mutually_exclusive_group(required=True)
    grp.add_argument("--letter", help="Target letter (A–Z); loads asl_<LETTER>_mlp.pt")
    grp.add_argument("--model", help="Path to a trained .pt checkpoint")
    ap.add_argument("--camera", type=int, default=0, help="OpenCV camera index (default: 0)")
    args = ap.parse_args()

    # Resolve the model path
    model_path = args.model
    if model_path is None:
        letter = args.letter.upper()
        model_path = f"asl_{letter}_mlp.pt"

    if not os.path.exists(model_path):
        raise SystemExit(f"❌ Model file not found: {model_path}")

    # Load the checkpoint with weights_only=False: the file stores plain Python
    # and NumPy objects (class list, standardization stats) alongside tensors.
    state = torch.load(model_path, map_location="cpu", weights_only=False)
    classes = state["classes"]
    X_mean = state["X_mean"]
    X_std = state["X_std"]

    # Convert X_mean/X_std to NumPy no matter how they were saved
    if isinstance(X_mean, torch.Tensor):
        X_mean = X_mean.cpu().numpy()
    if isinstance(X_std, torch.Tensor):
        X_std = X_std.cpu().numpy()
    X_mean = np.asarray(X_mean, dtype=np.float32)
    X_std = np.asarray(X_std, dtype=np.float32) + 1e-6  # avoid division by zero

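    # For reference, the checkpoint is assumed to have been written by the
    # training script roughly like this (the exact call is an assumption;
    # adjust to match your trainer — the key names are what this script reads):
    #   torch.save({"model": model.state_dict(), "classes": classes,
    #               "X_mean": X_mean, "X_std": X_std}, "asl_A_mlp.pt")
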
    model = MLP(63, len(classes))  # 63 = 21 landmarks x 3 coords
    model.load_state_dict(state["model"])
    model.eval()

    # Prefer the Apple-silicon GPU (MPS) when available, otherwise run on CPU
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    model.to(device)

    hands = mp.solutions.hands.Hands(
        static_image_mode=False, max_num_hands=1, min_detection_confidence=0.5
    )

    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise SystemExit(f"❌ Could not open camera index {args.camera}")

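    # Note: 0.5 is MediaPipe's default detection threshold; raising it trades
    # fewer false detections for more "No hand" frames.
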
print(f"✅ Loaded {model_path} with classes {classes}")
|
||
print("Press 'q' to quit.")
|
||
while True:
|
||
ok, frame = cap.read()
|
||
if not ok: break
|
||
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||
res = hands.process(rgb)
|
||
|
||
label_text = "No hand"
|
||
if res.multi_hand_landmarks:
|
||
ih = res.multi_hand_landmarks[0]
|
||
handed = None
|
||
if res.multi_handedness:
|
||
handed = res.multi_handedness[0].classification[0].label
|
||
pts = np.array([[lm.x, lm.y, lm.z] for lm in ih.landmark], dtype=np.float32)
|
||
feat = normalize_landmarks(pts, handedness_label=handed)
|
||
# standardize
|
||
xn = (feat - X_mean.flatten()) / X_std.flatten()
|
||
xt = torch.from_numpy(xn).float().unsqueeze(0).to(device)
|
||
with torch.no_grad():
|
||
probs = torch.softmax(model(xt), dim=1)[0].cpu().numpy()
|
||
idx = int(probs.argmax())
|
||
label_text = f"{classes[idx]} {probs[idx]*100:.1f}%"
|
||
|
||
cv2.putText(frame, label_text, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.1, (0,255,0), 2)
|
||
cv2.imshow("ASL handshape demo", frame)
|
||
if cv2.waitKey(1) & 0xFF == ord('q'):
|
||
break
|
||
|
||
cap.release()
|
||
cv2.destroyAllWindows()
|
||
|
||
if __name__ == "__main__":
    main()