Use more modern function channel approach in websocket forwarder.

This commit is contained in:
Sascha L. Teichmann 2022-03-01 22:29:56 +01:00
parent da2b327985
commit 25e62b3a8b

View File

@ -15,10 +15,8 @@ import (
type websocketForwarder struct { type websocketForwarder struct {
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
register chan *connection connections map[*connection]struct{}
unregister chan *connection funcs chan func(*websocketForwarder)
broadcast chan msg
connections map[*connection]bool
} }
type connection struct { type connection struct {
@ -32,49 +30,74 @@ type msg struct {
} }
func newWebsocketForwarder() *websocketForwarder { func newWebsocketForwarder() *websocketForwarder {
upgrader := &websocket.Upgrader{ReadBufferSize: 512, WriteBufferSize: 2048} upgrader := &websocket.Upgrader{
ReadBufferSize: 512,
WriteBufferSize: 2048,
//CheckOrigin: func(*http.Request) bool { return true },
}
return &websocketForwarder{ return &websocketForwarder{
upgrader: upgrader, upgrader: upgrader,
register: make(chan *connection), connections: make(map[*connection]struct{}),
unregister: make(chan *connection), funcs: make(chan func(*websocketForwarder)),
broadcast: make(chan msg), }
connections: make(map[*connection]bool)}
} }
func (wsf *websocketForwarder) run() { func (wsf *websocketForwarder) run() {
for { for fn := range wsf.funcs {
select { fn(wsf)
case c := <-wsf.register: }
wsf.connections[c] = true }
case c := <-wsf.unregister:
if _, ok := wsf.connections[c]; ok { func (wsf *websocketForwarder) register(c *connection) {
wsf.funcs <- func(wsf *websocketForwarder) {
wsf.connections[c] = struct{}{}
}
}
func (wsf *websocketForwarder) unregister(c *connection) {
wsf.funcs <- func(wsf *websocketForwarder) {
if _, ok := wsf.connections[c]; ok {
delete(wsf.connections, c)
close(c.send)
}
}
}
func (wsf *websocketForwarder) send(m *msg) {
wsf.funcs <- func(wsf *websocketForwarder) {
if len(wsf.connections) == 0 {
return
}
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
if err := encoder.Encode(m); err != nil {
log.Printf("encoding changes failed: %s\n", err)
return
}
data := buf.Bytes()
for c := range wsf.connections {
select {
case c.send <- data:
default:
delete(wsf.connections, c) delete(wsf.connections, c)
close(c.send) close(c.send)
} }
case message := <-wsf.broadcast:
if len(wsf.connections) == 0 {
continue
}
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
if err := encoder.Encode(message); err != nil {
log.Printf("encoding changes failed: %s\n", err)
continue
}
m := buf.Bytes()
for c := range wsf.connections {
select {
case c.send <- m:
default:
delete(wsf.connections, c)
close(c.send)
}
}
} }
} }
} }
func (wsf *websocketForwarder) BaseTilesUpdated(changes []xz) {
wsf.send(&msg{Tiles: changes})
}
func (wsf *websocketForwarder) BroadcastPlayers(pls []*player) {
wsf.send(&msg{Pls: pls})
}
func (wsf *websocketForwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request) { func (wsf *websocketForwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
ws, err := wsf.upgrader.Upgrade(rw, r, nil) ws, err := wsf.upgrader.Upgrade(rw, r, nil)
if err != nil { if err != nil {
@ -82,20 +105,12 @@ func (wsf *websocketForwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request
return return
} }
c := &connection{ws: ws, send: make(chan []byte, 8)} c := &connection{ws: ws, send: make(chan []byte, 8)}
wsf.register <- c wsf.register(c)
defer func() { wsf.unregister <- c }() defer wsf.unregister(c)
go c.writer() go c.writer()
c.reader() c.reader()
} }
func (wsf *websocketForwarder) BaseTilesUpdated(changes []xz) {
wsf.broadcast <- msg{Tiles: changes}
}
func (wsf *websocketForwarder) BroadcastPlayers(pls []*player) {
wsf.broadcast <- msg{Pls: pls}
}
func (c *connection) writer() { func (c *connection) writer() {
defer c.ws.Close() defer c.ws.Close()
for msg := range c.send { for msg := range c.send {