feat add remote hand tracking backend
This commit is contained in:
@@ -0,0 +1,73 @@
|
||||
# Hand Tracking Backend
|
||||
|
||||
Remote-compatible Python backend for La-Fabrik hand tracking.
|
||||
|
||||
The browser captures webcam frames, downsizes them, sends JPEG frames to this backend over WebSocket, and receives hand landmarks plus pinch state.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r backend/requirements.txt
|
||||
python3 backend/download_model.py
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
python3 -m backend.main
|
||||
```
|
||||
|
||||
The WebSocket endpoint is:
|
||||
|
||||
```txt
|
||||
ws://localhost:8000/ws
|
||||
```
|
||||
|
||||
## Health Check
|
||||
|
||||
```txt
|
||||
http://localhost:8000/health
|
||||
```
|
||||
|
||||
## Message Flow
|
||||
|
||||
Client sends a compressed frame:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "frame",
|
||||
"timestamp": 1234567890,
|
||||
"width": 320,
|
||||
"height": 240,
|
||||
"image": "base64-jpeg"
|
||||
}
|
||||
```
|
||||
|
||||
Server responds with detected hands:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "hands",
|
||||
"timestamp": 1234567890,
|
||||
"hands": [
|
||||
{
|
||||
"x": 0.5,
|
||||
"y": 0.3,
|
||||
"z": 0.1,
|
||||
"handedness": "Right",
|
||||
"isPinch": true,
|
||||
"pinchDistance": 0.05,
|
||||
"score": 0.92
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- The backend does not read `cv2.VideoCapture(0)`.
|
||||
- This keeps local development and production behavior aligned.
|
||||
- Each browser connection sends its own webcam frames.
|
||||
- The backend rate-limits frames per connection and drops work when a client is already being processed.
|
||||
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConnection:
|
||||
id: str
|
||||
websocket: WebSocket
|
||||
is_processing: bool = False
|
||||
last_frame_at: float = 0.0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self._connections: dict[str, ClientConnection] = {}
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self._connections)
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> ClientConnection:
|
||||
await websocket.accept()
|
||||
connection = ClientConnection(id=str(uuid4()), websocket=websocket)
|
||||
self._connections[connection.id] = connection
|
||||
return connection
|
||||
|
||||
def disconnect(self, connection: ClientConnection) -> None:
|
||||
self._connections.pop(connection.id, None)
|
||||
|
||||
async def send(self, connection: ClientConnection, payload: dict[str, Any]) -> None:
|
||||
await connection.websocket.send_json(payload)
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
|
||||
MODEL_URL = "https://storage.googleapis.com/mediapipe-models/hand_landmarker/hand_landmarker/float16/1/hand_landmarker.task"
|
||||
MODEL_PATH = Path(__file__).with_name("hand_landmarker.task")
|
||||
|
||||
|
||||
def download_model() -> None:
|
||||
if MODEL_PATH.exists():
|
||||
print(f"Model already exists at {MODEL_PATH}")
|
||||
return
|
||||
|
||||
print("Downloading MediaPipe Hand Landmarker model...")
|
||||
urlretrieve(MODEL_URL, MODEL_PATH)
|
||||
print(f"Model downloaded to {MODEL_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
download_model()
|
||||
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
from mediapipe.tasks import python
|
||||
from mediapipe.tasks.python import vision
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HandData:
|
||||
x: float
|
||||
y: float
|
||||
z: float
|
||||
handedness: str
|
||||
is_pinch: bool
|
||||
pinch_distance: float
|
||||
score: float
|
||||
|
||||
def to_payload(self) -> dict[str, float | str | bool]:
|
||||
return {
|
||||
"x": self.x,
|
||||
"y": self.y,
|
||||
"z": self.z,
|
||||
"handedness": self.handedness,
|
||||
"isPinch": self.is_pinch,
|
||||
"pinchDistance": self.pinch_distance,
|
||||
"score": self.score,
|
||||
}
|
||||
|
||||
|
||||
class HandTracker:
|
||||
def __init__(self, max_hands: int = 2) -> None:
|
||||
model_path = Path(__file__).with_name("hand_landmarker.task")
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
"Missing hand_landmarker.task. Run `python backend/download_model.py`.",
|
||||
)
|
||||
|
||||
base_options = python.BaseOptions(model_asset_path=str(model_path))
|
||||
options = vision.HandLandmarkerOptions(
|
||||
base_options=base_options,
|
||||
running_mode=vision.RunningMode.IMAGE,
|
||||
num_hands=max_hands,
|
||||
)
|
||||
self._detector = vision.HandLandmarker.create_from_options(options)
|
||||
|
||||
def detect_from_base64_jpeg(self, image_base64: str) -> list[HandData]:
|
||||
image_data = base64.b64decode(image_base64, validate=True)
|
||||
image_buffer = np.frombuffer(image_data, dtype=np.uint8)
|
||||
frame = cv2.imdecode(image_buffer, cv2.IMREAD_COLOR)
|
||||
if frame is None:
|
||||
raise ValueError("Invalid JPEG frame")
|
||||
|
||||
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_frame)
|
||||
result = self._detector.detect(mp_image)
|
||||
return self._to_hands(result)
|
||||
|
||||
def close(self) -> None:
|
||||
self._detector.close()
|
||||
|
||||
def _to_hands(self, result: vision.HandLandmarkerResult) -> list[HandData]:
|
||||
hands: list[HandData] = []
|
||||
if not result.hand_landmarks or not result.handedness:
|
||||
return hands
|
||||
|
||||
for landmarks, handedness_categories in zip(
|
||||
result.hand_landmarks,
|
||||
result.handedness,
|
||||
):
|
||||
index_tip = landmarks[8]
|
||||
thumb_tip = landmarks[4]
|
||||
pinch_distance = self._calculate_distance(index_tip, thumb_tip)
|
||||
handedness = handedness_categories[0]
|
||||
|
||||
hands.append(
|
||||
HandData(
|
||||
x=index_tip.x,
|
||||
y=index_tip.y,
|
||||
z=index_tip.z,
|
||||
handedness=handedness.category_name,
|
||||
is_pinch=pinch_distance < 0.07,
|
||||
pinch_distance=pinch_distance,
|
||||
score=handedness.score,
|
||||
),
|
||||
)
|
||||
|
||||
return hands
|
||||
|
||||
def _calculate_distance(self, point_a: Any, point_b: Any) -> float:
|
||||
return math.sqrt(
|
||||
(point_a.x - point_b.x) ** 2
|
||||
+ (point_a.y - point_b.y) ** 2
|
||||
+ (point_a.z - point_b.z) ** 2,
|
||||
)
|
||||
|
||||
|
||||
def now_ms() -> int:
|
||||
return time.monotonic_ns() // 1_000_000
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from backend.connection_manager import ClientConnection, ConnectionManager
|
||||
from backend.hand_tracker import HandTracker, now_ms
|
||||
|
||||
|
||||
MAX_FRAME_BYTES = 220_000
|
||||
MIN_FRAME_INTERVAL_SECONDS = 0.08
|
||||
|
||||
manager = ConnectionManager()
|
||||
tracker: HandTracker | None = None
|
||||
detection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global tracker
|
||||
tracker = HandTracker(max_hands=2)
|
||||
yield
|
||||
if tracker:
|
||||
tracker.close()
|
||||
|
||||
|
||||
app = FastAPI(title="La-Fabrik Hand Tracking", lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> JSONResponse:
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": "ok",
|
||||
"connections": manager.count,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket) -> None:
|
||||
connection = await manager.connect(websocket)
|
||||
await manager.send(connection, status_payload("connected"))
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
response = await handle_message(connection, message)
|
||||
await manager.send(connection, response)
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(connection)
|
||||
except Exception as error:
|
||||
await manager.send(connection, error_payload(str(error)))
|
||||
manager.disconnect(connection)
|
||||
|
||||
|
||||
async def handle_message(
|
||||
connection: ClientConnection,
|
||||
message: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if message.get("type") != "frame":
|
||||
return error_payload("Unsupported message type")
|
||||
|
||||
current_time = asyncio.get_running_loop().time()
|
||||
if current_time - connection.last_frame_at < MIN_FRAME_INTERVAL_SECONDS:
|
||||
return status_payload("rate_limited")
|
||||
|
||||
if connection.is_processing:
|
||||
return status_payload("busy")
|
||||
|
||||
image = message.get("image")
|
||||
if not isinstance(image, str):
|
||||
return error_payload("Missing image payload")
|
||||
|
||||
if len(image) > MAX_FRAME_BYTES:
|
||||
return error_payload("Frame payload too large")
|
||||
|
||||
if tracker is None:
|
||||
return error_payload("Hand tracker is not ready")
|
||||
|
||||
if detection_lock.locked():
|
||||
return status_payload("busy")
|
||||
|
||||
connection.last_frame_at = current_time
|
||||
connection.is_processing = True
|
||||
try:
|
||||
async with detection_lock:
|
||||
hands = await asyncio.to_thread(tracker.detect_from_base64_jpeg, image)
|
||||
return {
|
||||
"type": "hands",
|
||||
"timestamp": now_ms(),
|
||||
"hands": [hand.to_payload() for hand in hands],
|
||||
}
|
||||
finally:
|
||||
connection.is_processing = False
|
||||
|
||||
|
||||
def status_payload(status: str) -> dict[str, str | int]:
|
||||
return {
|
||||
"type": "status",
|
||||
"timestamp": now_ms(),
|
||||
"status": status,
|
||||
}
|
||||
|
||||
|
||||
def error_payload(message: str) -> dict[str, str | int | list[Any]]:
|
||||
return {
|
||||
"type": "error",
|
||||
"timestamp": now_ms(),
|
||||
"hands": [],
|
||||
"message": message,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -0,0 +1,5 @@
|
||||
fastapi==0.115.0
|
||||
uvicorn[standard]==0.30.6
|
||||
opencv-python-headless==4.10.0.84
|
||||
mediapipe==0.10.20
|
||||
numpy==1.26.4
|
||||
Reference in New Issue
Block a user