diff --git a/cmd/mtseeder/baselevel.go b/cmd/mtseeder/baselevel.go index 96bf923..d11484a 100644 --- a/cmd/mtseeder/baselevel.go +++ b/cmd/mtseeder/baselevel.go @@ -56,7 +56,7 @@ func createTiles( } func createBaseLevel( - dbcc common.DBClientCreator, + dbcf common.DBClientFactory, xMin, yMin, zMin, xMax, yMax, zMax int, transparent bool, transparentDim float32, colorsFile string, bg color.RGBA, outDir string, @@ -81,7 +81,7 @@ func createBaseLevel( for i := 0; i < numWorkers; i++ { var client common.DBClient - if client, err = dbcc(); err != nil { + if client, err = dbcf.Create(); err != nil { return } done.Add(1) diff --git a/cmd/mtseeder/main.go b/cmd/mtseeder/main.go index 7ba1fba..31d8664 100644 --- a/cmd/mtseeder/main.go +++ b/cmd/mtseeder/main.go @@ -68,12 +68,16 @@ func main() { bg := common.ParseColorDefault(bgColor, common.BackgroundColor) - dbcc := common.CreateDBClientCreator(host, port) + dbcf, err := common.CreateDBClientFactory(host, port) + if err != nil { + log.Fatalf("error: %s\n", err) + } + defer dbcf.Close() if !skipBaseLevel { td := common.Clamp32f(float32(transparentDim/100.0), 0.0, 1.0) if err := createBaseLevel( - dbcc, + dbcf, xMin, yMin, zMin, xMax, yMax, zMax, transparent, td, colorsFile, bg, diff --git a/cmd/mttilemapper/main.go b/cmd/mttilemapper/main.go index 1801d7c..d24ae3b 100644 --- a/cmd/mttilemapper/main.go +++ b/cmd/mttilemapper/main.go @@ -83,7 +83,12 @@ func main() { colors.TransparentDim = common.Clamp32f( float32(transparentDim/100.0), 0.0, 100.0) - client, err := common.CreateDBClientCreator(host, port)() + cf, err := common.CreateDBClientFactory(host, port) + if err != nil { + log.Fatalf("error: %v\n", err) + } + + client, err := cf.Create() if err != nil { log.Fatalf("Cannot connect to '%s:%d': %s", host, port, err) } diff --git a/cmd/mtwebmapper/main.go b/cmd/mtwebmapper/main.go index e80b5a8..711685d 100644 --- a/cmd/mtwebmapper/main.go +++ b/cmd/mtwebmapper/main.go @@ -118,7 +118,11 @@ func main() { colors.TransparentDim = common.Clamp32f( float32(transparentDim/100.0), 0.0, 100.0) - dbcc := common.CreateDBClientCreator(redisHost, redisPort) + dbcf, err := common.CreateDBClientFactory(redisHost, redisPort) + if err != nil { + log.Fatalf("error: %v\n", err) + } + defer dbcf.Close() var allowedUpdateIps []net.IP if allowedUpdateIps, err = ipsFromHosts(updateHosts); err != nil { @@ -127,7 +131,7 @@ func main() { tu := newTileUpdater( mapDir, - dbcc, + dbcf, allowedUpdateIps, colors, bg, yMin, yMax, diff --git a/cmd/mtwebmapper/tilesupdater.go b/cmd/mtwebmapper/tilesupdater.go index ebb8173..7b764d2 100644 --- a/cmd/mtwebmapper/tilesupdater.go +++ b/cmd/mtwebmapper/tilesupdater.go @@ -38,7 +38,7 @@ type tileUpdater struct { changes map[xz]struct{} btu baseTilesUpdates mapDir string - dbcc common.DBClientCreator + dbcf common.DBClientFactory ips []net.IP colors *common.Colors bg color.RGBA @@ -82,7 +82,7 @@ func (c xz) parent() xzm { func newTileUpdater( mapDir string, - dbcc common.DBClientCreator, + dbcf common.DBClientFactory, ips []net.IP, colors *common.Colors, bg color.RGBA, @@ -94,7 +94,7 @@ func newTileUpdater( tu := tileUpdater{ btu: btu, mapDir: mapDir, - dbcc: dbcc, + dbcf: dbcf, ips: ips, changes: map[xz]struct{}{}, colors: colors, @@ -262,10 +262,11 @@ func (tu *tileUpdater) doUpdates() { for i, n := 0, common.Min(tu.workers, len(changes)); i < n; i++ { var client common.DBClient var err error - if client, err = tu.dbcc(); err != nil { + if client, err = tu.dbcf.Create(); err != nil { log.Printf("WARN: Cannot connect to redis server: %s\n", err) continue } + btc := common.NewBaseTileCreator( client, tu.colors, tu.bg, tu.yMin, tu.yMax, diff --git a/common/clientfactory.go b/common/clientfactory.go index fb87343..35333bc 100644 --- a/common/clientfactory.go +++ b/common/clientfactory.go @@ -5,11 +5,18 @@ package common import ( - "fmt" "strings" ) -type DBClientCreator func() (DBClient, error) +type DBClient interface { + QueryCuboid(cuboid Cuboid, fn func(*Block) *Block) (count int, err error) + Close() error +} + +type DBClientFactory interface { + Create() (DBClient, error) + Close() error +} func IsPostgreSQL(host string) (string, bool) { if strings.HasPrefix(host, "postgres:") { @@ -18,29 +25,11 @@ func IsPostgreSQL(host string) (string, bool) { return "", false } -func CreateDBClientCreator(host string, port int) DBClientCreator { +func CreateDBClientFactory(host string, port int) (DBClientFactory, error) { - if host, ok := IsPostgreSQL(host); ok { - return func() (DBClient, error) { - return NewPGClient(host) - } + if connS, ok := IsPostgreSQL(host); ok { + return NewPGClientFactory(connS) } - var address string - if strings.ContainsRune(host, '/') { - address = host - } else { - address = fmt.Sprintf("%s:%d", host, port) - } - - var proto string - if strings.ContainsRune(address, '/') { - proto = "unix" - } else { - proto = "tcp" - } - - return func() (DBClient, error) { - return NewRedisClient(proto, address) - } + return NewRedisClientFactory(host, port) } diff --git a/common/dbclient.go b/common/dbclient.go deleted file mode 100644 index 29af14c..0000000 --- a/common/dbclient.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2022 by Sascha L. Teichmann -// Use of this source code is governed by the MIT license -// that can be found in the LICENSE file. - -package common - -type DBClient interface { - QueryCuboid(cuboid Cuboid, fn func(*Block) *Block) (count int, err error) - Close() error -} diff --git a/common/pgclient.go b/common/pgclient.go index a419a09..7c691b3 100644 --- a/common/pgclient.go +++ b/common/pgclient.go @@ -19,24 +19,43 @@ WHERE posz BETWEEN $5 AND $6` type PGClient struct { - db *sql.DB + conn *sql.Conn queryCuboidStmt *sql.Stmt } -func NewPGClient(connS string) (*PGClient, error) { +type PGClientFactory struct { + db *sql.DB +} + +func NewPGClientFactory(connS string) (*PGClientFactory, error) { db, err := sql.Open("pgx", connS) if err != nil { return nil, err } - stmt, err := db.Prepare(queryCuboidSQL) + return &PGClientFactory{db: db}, nil +} + +func (pgcf *PGClientFactory) Close() error { + return pgcf.db.Close() +} + +func (pgcf *PGClientFactory) Create() (DBClient, error) { + + ctx := context.Background() + + conn, err := pgcf.db.Conn(ctx) if err != nil { return nil, err } - client := PGClient{ - db: db, - queryCuboidStmt: stmt, + stmt, err := conn.PrepareContext(ctx, queryCuboidSQL) + if err != nil { + conn.Close() + return nil, err } - return &client, nil + return &PGClient{ + conn: conn, + queryCuboidStmt: stmt, + }, nil } func (pgc *PGClient) QueryCuboid( @@ -94,5 +113,5 @@ func (pgc *PGClient) Close() error { if pgc.queryCuboidStmt != nil { pgc.queryCuboidStmt.Close() } - return pgc.db.Close() + return pgc.conn.Close() } diff --git a/common/redisclient.go b/common/redisclient.go index 23c04bb..b4d36c6 100644 --- a/common/redisclient.go +++ b/common/redisclient.go @@ -11,9 +11,43 @@ import ( "fmt" "net" "strconv" + "strings" "unicode" ) +type RedisClientFactory struct { + proto string + address string +} + +func NewRedisClientFactory(host string, port int) (*RedisClientFactory, error) { + var address string + if strings.ContainsRune(host, '/') { + address = host + } else { + address = fmt.Sprintf("%s:%d", host, port) + } + + var proto string + if strings.ContainsRune(address, '/') { + proto = "unix" + } else { + proto = "tcp" + } + return &RedisClientFactory{ + proto: proto, + address: address, + }, nil +} + +func (rcf *RedisClientFactory) Close() error { + return nil +} + +func (rcf *RedisClientFactory) Create() (DBClient, error) { + return NewRedisClient(rcf.proto, rcf.address) +} + type RedisClient struct { conn net.Conn reader *bufio.Reader