// Copyright 2014, 2015 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.

package main

import (
	"encoding/json"
	"log"
	"net/http"

	"github.com/gorilla/websocket"
)

// websocketForwarder keeps a set of websocket connections and pushes
// JSON encoded messages to them.
//
// upgrader turns incoming HTTP requests into websockets. connections
// is the set of currently registered connections. funcs carries
// closures that run executes one at a time; register, unregister,
// setInit, send and singleSend all perform their work through it.
// init, if non-nil, is called by ServeHTTP for every new connection.
type websocketForwarder struct {
	upgrader    *websocket.Upgrader
	connections map[*connection]struct{}
	funcs       chan func(*websocketForwarder)
	init        func(*websocketForwarder, *connection)
}

// connection pairs a websocket with its queue of outgoing messages.
//
// ws is the underlying websocket. send holds the messages that the
// writer method writes to ws; closing send ends the writer loop.
type connection struct {
	ws   *websocket.Conn
	send chan []byte
}

// tilesMsg is the message built by BaseTilesUpdated. Its JSON form
// stores the changed tiles under the key "tiles".
type tilesMsg struct {
	Tiles []xz `json:"tiles"`
}

// plsMsg is the message built by BroadcastPlayers. Its JSON form
// stores the players under the key "players".
type plsMsg struct {
	Pls []*player `json:"players"`
}

// newWebsocketForwarder creates a forwarder with an empty connection
// set and no init hook. The funcs channel is unbuffered, so register,
// unregister, setInit, send and singleSend block until run receives
// their closure.
func newWebsocketForwarder() *websocketForwarder {
	wsf := new(websocketForwarder)
	wsf.upgrader = &websocket.Upgrader{
		ReadBufferSize:  512,
		WriteBufferSize: 2048,
		//CheckOrigin:     func(*http.Request) bool { return true },
	}
	wsf.connections = map[*connection]struct{}{}
	wsf.funcs = make(chan func(*websocketForwarder))
	return wsf
}

// run executes the closures arriving on wsf.funcs one after another,
// handing wsf to each of them. It returns once the channel is closed.
func (wsf *websocketForwarder) run() {
	for {
		fn, open := <-wsf.funcs
		if !open {
			return
		}
		fn(wsf)
	}
}

// register asks the run loop to add c to the set of connections.
// It blocks until run has received the request.
func (wsf *websocketForwarder) register(c *connection) {
	add := func(f *websocketForwarder) {
		f.connections[c] = struct{}{}
	}
	wsf.funcs <- add
}

// unregister asks the run loop to remove c from the set of
// connections and to close its send channel. A connection that is no
// longer in the set is left untouched.
func (wsf *websocketForwarder) unregister(c *connection) {
	remove := func(f *websocketForwarder) {
		_, known := f.connections[c]
		if !known {
			// send and singleSend close c.send when they drop a
			// connection; closing it a second time would panic.
			return
		}
		delete(f.connections, c)
		close(c.send)
	}
	wsf.funcs <- remove
}

// setInit asks the run loop to store init as the forwarder's init
// hook. The assignment itself is performed by run, after this call
// has handed over the request.
func (wsf *websocketForwarder) setInit(init func(*websocketForwarder, *connection)) {
	assign := func(f *websocketForwarder) {
		f.init = init
	}
	wsf.funcs <- assign
}

// send asks the run loop to JSON encode m and queue the result on
// every registered connection. Encoding errors are logged and the
// message is discarded. A connection whose send queue is full is
// removed from the set and its send channel is closed.
func (wsf *websocketForwarder) send(m interface{}) {
	broadcast := func(f *websocketForwarder) {
		// Nobody is listening, so skip the encoding.
		if len(f.connections) < 1 {
			return
		}

		payload, err := json.Marshal(m)
		if err != nil {
			log.Printf("encoding failed. %v\n", err)
			return
		}

		for conn := range f.connections {
			select {
			case conn.send <- payload:
			default:
				// Never block the run loop on a connection that
				// does not drain its queue; drop it instead.
				delete(f.connections, conn)
				close(conn.send)
			}
		}
	}
	wsf.funcs <- broadcast
}

// singleSend asks the run loop to JSON encode m and queue the result
// on c only. Nothing happens if c is not registered. Encoding errors
// are logged and the message is discarded. If c's send queue is full,
// c is removed from the set and its send channel is closed.
func (wsf *websocketForwarder) singleSend(c *connection, m interface{}) {
	deliver := func(f *websocketForwarder) {
		if _, known := f.connections[c]; !known {
			return
		}

		payload, err := json.Marshal(m)
		if err != nil {
			log.Printf("encoding failed. %v\n", err)
			return
		}

		select {
		case c.send <- payload:
		default:
			// Never block the run loop on a connection that does
			// not drain its queue; drop it instead.
			delete(f.connections, c)
			close(c.send)
		}
	}
	wsf.funcs <- deliver
}

// BaseTilesUpdated wraps changes in a tilesMsg and broadcasts it to
// all registered connections via send.
func (wsf *websocketForwarder) BaseTilesUpdated(changes []xz) {
	msg := new(tilesMsg)
	msg.Tiles = changes
	wsf.send(msg)
}

// BroadcastPlayers wraps pls in a plsMsg and broadcasts it to all
// registered connections via send.
func (wsf *websocketForwarder) BroadcastPlayers(pls []*player) {
	msg := new(plsMsg)
	msg.Pls = pls
	wsf.send(msg)
}

// ServeHTTP upgrades the request to a websocket, registers the new
// connection and serves it until reading from the socket fails.
//
// A failed upgrade is logged and the request is abandoned. Otherwise
// the connection gets a send queue of 8 messages and a writer
// goroutine. If an init hook is set it is called once with the new
// connection, and the handler then blocks in reader. The connection
// is unregistered when the handler returns.
func (wsf *websocketForwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	ws, err := wsf.upgrader.Upgrade(rw, r, nil)
	if err != nil {
		log.Printf("Cannot upgrade to websocket: %s\n", err)
		return
	}
	c := &connection{ws: ws, send: make(chan []byte, 8)}
	wsf.register(c)
	defer wsf.unregister(c)
	go c.writer()

	// wsf.init is assigned by the closure that setInit hands to run.
	// Reading the field directly from this handler goroutine would be
	// an unsynchronized access, so fetch it through the run loop too.
	// The channel is buffered so the run loop never blocks on the reply.
	hook := make(chan func(*websocketForwarder, *connection), 1)
	wsf.funcs <- func(wsf *websocketForwarder) {
		hook <- wsf.init
	}
	if initFn := <-hook; initFn != nil {
		// Called on the handler goroutine, not inside the run loop,
		// so the hook may use singleSend without deadlocking.
		initFn(wsf, c)
	}
	c.reader()
}

// writer writes every message taken from c.send to the websocket as
// a text message. It stops when c.send is closed or a write fails,
// and closes the websocket on the way out.
func (c *connection) writer() {
	defer c.ws.Close()
	for {
		msg, open := <-c.send
		if !open {
			return
		}
		if err := c.ws.WriteMessage(websocket.TextMessage, msg); err != nil {
			return
		}
	}
}

// reader keeps requesting the next incoming message until the
// websocket reports an error, then closes the websocket. The content
// of the messages is never looked at.
func (c *connection) reader() {
	defer c.ws.Close()
	var err error
	for err == nil {
		// Only the error matters; the message itself is ignored.
		_, _, err = c.ws.NextReader()
	}
}