Files
vikunja/pkg/websocket/connection.go
kolaente 9255fe07a9 feat(websocket): add message types, connection hub, and connection handler
Add the core WebSocket infrastructure:

- Message type definitions for the wire protocol (subscribe, unsubscribe,
  auth, error, push events)
- In-memory connection hub that tracks per-user connections and routes
  messages to subscribed clients
- Connection wrapper with auth-after-connect flow: connections start
  unauthenticated, client sends JWT as first message, only then can
  subscribe to event topics

Includes auth timeout (30s), shared cancellation context for read/write
loops, hub map cleanup on last connection removal, and proper error
delivery before closing on auth failure.
2026-04-02 16:30:23 +00:00

272 lines
7.2 KiB
Go

// Vikunja is a to-do list application to facilitate your life.
// Copyright 2018-present Vikunja and contributors. All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package websocket
import (
"context"
"encoding/json"
"sync"
"time"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/modules/auth"
"github.com/coder/websocket"
)
// Timing and buffering parameters used by every Connection.
const (
	// sendBufSize is the capacity of a connection's outgoing message channel.
	sendBufSize = 64

	// authTimeout is how long a connection may remain unauthenticated before
	// ReadLoop's timer closes it.
	authTimeout = 30 * time.Second

	// pingInterval is the period of the ping ticker in WriteLoop.
	pingInterval = 30 * time.Second

	// writeTimeout bounds each individual write or ping on the socket.
	writeTimeout = 10 * time.Second
)
// Connection wraps a single WebSocket connection together with the
// per-client state kept for it: authentication status, the authenticated
// user's ID, the set of subscribed events and a buffered queue of outgoing
// messages.
type Connection struct {
	// ws is the underlying socket used for every read, write, ping and close.
	ws *websocket.Conn

	// hub is where the connection is registered after a successful auth and
	// unregistered when ReadLoop exits.
	hub *Hub

	// mu is held by the accessor methods and by handleAuth whenever userID,
	// authenticated or subscriptions is read or written.
	mu sync.RWMutex

	// userID is set by handleAuth on successful authentication; it keeps its
	// zero value until then.
	userID int64

	// authenticated starts false and is set to true by handleAuth.
	authenticated bool

	// subscriptions is keyed by event name; an event is subscribed when its
	// value is true.
	subscriptions map[string]bool

	// send queues outgoing messages for WriteLoop. Producers in this file use
	// non-blocking sends and drop the message when the buffer is full.
	send chan OutgoingMessage
}
// NewConnection builds a Connection around ws that reports to hub. The
// result is unauthenticated, has no subscriptions, and owns an empty
// outgoing queue with room for sendBufSize messages.
func NewConnection(ws *websocket.Conn, hub *Hub) *Connection {
	conn := new(Connection) // zero value: authenticated == false, userID == 0
	conn.ws = ws
	conn.hub = hub
	conn.subscriptions = map[string]bool{}
	conn.send = make(chan OutgoingMessage, sendBufSize)
	return conn
}
// Subscribe records that this connection wants to receive the given event.
// It takes the write lock for the duration of the update. Subscribing to an
// event that is already subscribed leaves the set unchanged.
//
// The event name is stored as given; handleMessage checks it with
// isValidEvent before calling this method.
func (c *Connection) Subscribe(event string) {
	c.mu.Lock()
	// Deferred so the lock is released even if the map write below panics
	// (for example on a Connection whose subscriptions map was never made).
	defer c.mu.Unlock()
	c.subscriptions[event] = true
}
// Unsubscribe drops event from this connection's subscription set under the
// write lock. Removing an event that was never subscribed is a no-op.
func (c *Connection) Unsubscribe(event string) {
	c.mu.Lock()
	delete(c.subscriptions, event)
	c.mu.Unlock()
}
// IsSubscribed reports whether this connection currently subscribes to
// event. The lookup happens under the read lock.
func (c *Connection) IsSubscribed(event string) bool {
	c.mu.RLock()
	subscribed := c.subscriptions[event]
	c.mu.RUnlock()
	return subscribed
}
// IsAuthenticated reports whether the connection has completed
// authentication. The flag is read under the read lock.
func (c *Connection) IsAuthenticated() bool {
	c.mu.RLock()
	authed := c.authenticated
	c.mu.RUnlock()
	return authed
}
// UserID returns the ID stored for this connection's user, read under the
// read lock. It is 0 until handleAuth has stored an ID.
func (c *Connection) UserID() int64 {
	c.mu.RLock()
	id := c.userID
	c.mu.RUnlock()
	return id
}
// ReadLoop reads client messages from the WebSocket until a read fails or
// handleMessage asks for the connection to be closed. Each message is decoded
// as JSON into an IncomingMessage; payloads that fail to decode are logged
// and skipped.
//
// When the loop exits it calls cancel, unregisters the connection from the
// hub if it had authenticated, and closes the socket with a normal-closure
// status. A timer started on entry closes the socket with a policy-violation
// status if the connection is still unauthenticated after authTimeout.
func (c *Connection) ReadLoop(ctx context.Context, cancel context.CancelFunc) {
	teardown := func() {
		cancel()
		if c.IsAuthenticated() {
			c.hub.Unregister(c)
		}
		c.ws.Close(websocket.StatusNormalClosure, "")
	}
	defer teardown()

	// Enforce the auth deadline: closing the socket makes the blocked Read
	// below fail, which ends the loop.
	onAuthTimeout := func() {
		if c.IsAuthenticated() {
			return
		}
		log.Debugf("WebSocket: closing unauthenticated connection after timeout")
		c.ws.Close(websocket.StatusPolicyViolation, "auth timeout")
	}
	authTimer := time.AfterFunc(authTimeout, onAuthTimeout)
	defer authTimer.Stop()

	for {
		_, payload, readErr := c.ws.Read(ctx)
		if readErr != nil {
			switch websocket.CloseStatus(readErr) {
			case websocket.StatusNormalClosure:
				log.Debugf("WebSocket: connection closed normally for user %d", c.UserID())
			default:
				log.Debugf("WebSocket: read error for user %d: %v", c.UserID(), readErr)
			}
			return
		}

		var incoming IncomingMessage
		if decodeErr := json.Unmarshal(payload, &incoming); decodeErr != nil {
			log.Warningf("WebSocket: invalid message: %v", decodeErr)
			continue
		}

		if keepOpen := c.handleMessage(ctx, incoming); !keepOpen {
			return // close connection
		}
	}
}
// handleMessage dispatches a decoded client message on its Action.
//
//   - ActionAuth is delegated to handleAuth, whose result is returned as-is.
//   - ActionSubscribe requires an authenticated connection and an event
//     accepted by isValidEvent; otherwise an "auth_required" or
//     "invalid_event" error is queued for the client.
//   - ActionUnsubscribe requires an authenticated connection; the event name
//     is not validated because removing an unknown key is harmless.
//   - Any other action is logged and ignored.
//
// It returns false when the connection should be closed, which here happens
// only if handleAuth returns false. Every other path returns true, including
// the ones that report an error to the client.
func (c *Connection) handleMessage(ctx context.Context, msg IncomingMessage) bool {
	switch msg.Action {
	case ActionAuth:
		return c.handleAuth(ctx, msg.Token)
	case ActionSubscribe:
		if !c.IsAuthenticated() {
			c.sendError("auth_required", "")
			return true
		}
		if !isValidEvent(msg.Event) {
			c.sendError("invalid_event", msg.Event)
			return true
		}
		c.Subscribe(msg.Event)
		log.Debugf("WebSocket: user %d subscribed to %s", c.UserID(), msg.Event)
	case ActionUnsubscribe:
		if !c.IsAuthenticated() {
			c.sendError("auth_required", "")
			return true
		}
		c.Unsubscribe(msg.Event)
		// msg.Event is client-supplied and unvalidated on this path, so it is
		// logged quoted: embedded newlines or control characters are escaped
		// and cannot forge log lines.
		log.Debugf("WebSocket: user %d unsubscribed from %q", c.UserID(), msg.Event)
	default:
		log.Warningf("WebSocket: unknown action %q", msg.Action)
	}
	return true
}
// handleAuth processes an auth message carrying token. It returns false when
// the connection should be closed, which happens only when the token is
// rejected.
//
// A connection that is already authenticated gets an "already_authenticated"
// error and stays open. On success the user ID and authenticated flag are
// stored under the write lock, the connection is registered with the hub,
// and an auth-success message is queued; that message is dropped with a
// warning if the send buffer is full.
func (c *Connection) handleAuth(ctx context.Context, token string) bool {
	if c.IsAuthenticated() {
		c.sendError("already_authenticated", "")
		return true
	}

	uid, authErr := auth.GetUserIDFromToken(token)
	if authErr != nil {
		log.Debugf("WebSocket: auth failed: %v", authErr)
		// Bypass the send channel: returning false makes ReadLoop close the
		// socket right away, before WriteLoop would get to drain the queue.
		rejection := OutgoingMessage{Error: "invalid_token"}
		c.writeMessageDirect(ctx, rejection)
		return false
	}

	c.mu.Lock()
	c.authenticated = true
	c.userID = uid
	c.mu.Unlock()

	c.hub.Register(c)

	// Queue the auth success without blocking the read loop.
	confirmation := OutgoingMessage{Action: ActionAuthSuccess, Success: true}
	select {
	case c.send <- confirmation:
	default:
		log.Warningf("WebSocket: send buffer full for user %d", uid)
	}

	log.Debugf("WebSocket: user %d authenticated", uid)
	return true
}
// writeMessageDirect JSON-encodes msg and writes it to the socket as a text
// frame without going through the send channel. The write is bounded by
// writeTimeout. Use it for messages that must go out before the connection
// is closed. Marshal and write failures are logged and not returned.
func (c *Connection) writeMessageDirect(ctx context.Context, msg OutgoingMessage) {
	deadlineCtx, release := context.WithTimeout(ctx, writeTimeout)
	defer release()

	payload, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		log.Errorf("WebSocket: marshal error: %v", marshalErr)
		return
	}

	writeErr := c.ws.Write(deadlineCtx, websocket.MessageText, payload)
	if writeErr != nil {
		log.Debugf("WebSocket: direct write error: %v", writeErr)
	}
}
// sendError queues an error message for the client, tagged with event (which
// may be empty). The send never blocks: if the outgoing buffer is full the
// error is dropped and a warning is logged.
func (c *Connection) sendError(errMsg, event string) {
	failure := OutgoingMessage{Error: errMsg, Event: event}
	select {
	case c.send <- failure:
		// Queued; WriteLoop delivers it.
	default:
		log.Warningf("WebSocket: send buffer full, dropping error")
	}
}
// WriteLoop drains the send channel, writing each message to the WebSocket
// as a JSON text frame, and pings the peer every pingInterval. Each write and
// ping is bounded by writeTimeout.
//
// The loop ends when ctx is done, when the send channel is closed, or when a
// write or ping fails. A message that cannot be marshalled is logged and
// skipped. cancel is always called on exit; it is presumably the cancel
// function of the context shared with ReadLoop, so that a failed write also
// stops the reader. Confirm against the caller.
func (c *Connection) WriteLoop(ctx context.Context, cancel context.CancelFunc) {
	defer cancel()

	ticker := time.NewTicker(pingInterval)
	defer ticker.Stop()

	for {
		select {
		case msg, ok := <-c.send:
			if !ok {
				// Channel closed: no further messages will arrive.
				return
			}
			// Marshal before starting the timeout so the whole writeTimeout
			// budget goes to the network write and a marshal failure has no
			// context to clean up.
			data, err := json.Marshal(msg)
			if err != nil {
				log.Errorf("WebSocket: marshal error: %v", err)
				continue
			}
			// Named cancelWrite so it does not shadow the cancel parameter.
			writeCtx, cancelWrite := context.WithTimeout(ctx, writeTimeout)
			err = c.ws.Write(writeCtx, websocket.MessageText, data)
			cancelWrite()
			if err != nil {
				log.Debugf("WebSocket: write error for user %d: %v", c.UserID(), err)
				return
			}
		case <-ticker.C:
			pingCtx, cancelPing := context.WithTimeout(ctx, writeTimeout)
			err := c.ws.Ping(pingCtx)
			cancelPing()
			if err != nil {
				log.Debugf("WebSocket: ping error for user %d: %v", c.UserID(), err)
				return
			}
		case <-ctx.Done():
			return
		}
	}
}
// validEvents lists the event names clients may subscribe to. A name counts
// as valid only when it is present with a true value.
var validEvents = map[string]bool{"notification.created": true}

// isValidEvent reports whether event is one of the subscribable event names.
func isValidEvent(event string) bool {
	allowed, known := validEvents[event]
	return known && allowed
}