Use a more general db client factory approach.

This commit is contained in:
Sascha L. Teichmann
2022-03-01 14:47:14 +01:00
parent 834c8a9bc6
commit c507663826
9 changed files with 99 additions and 53 deletions

View File

@ -5,11 +5,18 @@
package common
import (
"fmt"
"strings"
)
type DBClientCreator func() (DBClient, error)
type DBClient interface {
QueryCuboid(cuboid Cuboid, fn func(*Block) *Block) (count int, err error)
Close() error
}
type DBClientFactory interface {
Create() (DBClient, error)
Close() error
}
func IsPostgreSQL(host string) (string, bool) {
if strings.HasPrefix(host, "postgres:") {
@ -18,29 +25,11 @@ func IsPostgreSQL(host string) (string, bool) {
return "", false
}
func CreateDBClientCreator(host string, port int) DBClientCreator {
func CreateDBClientFactory(host string, port int) (DBClientFactory, error) {
if host, ok := IsPostgreSQL(host); ok {
return func() (DBClient, error) {
return NewPGClient(host)
}
if connS, ok := IsPostgreSQL(host); ok {
return NewPGClientFactory(connS)
}
var address string
if strings.ContainsRune(host, '/') {
address = host
} else {
address = fmt.Sprintf("%s:%d", host, port)
}
var proto string
if strings.ContainsRune(address, '/') {
proto = "unix"
} else {
proto = "tcp"
}
return func() (DBClient, error) {
return NewRedisClient(proto, address)
}
return NewRedisClientFactory(host, port)
}

View File

@ -1,10 +0,0 @@
// Copyright 2022 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.
package common
type DBClient interface {
QueryCuboid(cuboid Cuboid, fn func(*Block) *Block) (count int, err error)
Close() error
}

View File

@ -19,24 +19,43 @@ WHERE
posz BETWEEN $5 AND $6`
type PGClient struct {
db *sql.DB
conn *sql.Conn
queryCuboidStmt *sql.Stmt
}
func NewPGClient(connS string) (*PGClient, error) {
type PGClientFactory struct {
db *sql.DB
}
func NewPGClientFactory(connS string) (*PGClientFactory, error) {
db, err := sql.Open("pgx", connS)
if err != nil {
return nil, err
}
stmt, err := db.Prepare(queryCuboidSQL)
return &PGClientFactory{db: db}, nil
}
func (pgcf *PGClientFactory) Close() error {
return pgcf.db.Close()
}
func (pgcf *PGClientFactory) Create() (DBClient, error) {
ctx := context.Background()
conn, err := pgcf.db.Conn(ctx)
if err != nil {
return nil, err
}
client := PGClient{
db: db,
queryCuboidStmt: stmt,
stmt, err := conn.PrepareContext(ctx, queryCuboidSQL)
if err != nil {
conn.Close()
return nil, err
}
return &client, nil
return &PGClient{
conn: conn,
queryCuboidStmt: stmt,
}, nil
}
func (pgc *PGClient) QueryCuboid(
@ -94,5 +113,5 @@ func (pgc *PGClient) Close() error {
if pgc.queryCuboidStmt != nil {
pgc.queryCuboidStmt.Close()
}
return pgc.db.Close()
return pgc.conn.Close()
}

View File

@ -11,9 +11,43 @@ import (
"fmt"
"net"
"strconv"
"strings"
"unicode"
)
type RedisClientFactory struct {
proto string
address string
}
func NewRedisClientFactory(host string, port int) (*RedisClientFactory, error) {
var address string
if strings.ContainsRune(host, '/') {
address = host
} else {
address = fmt.Sprintf("%s:%d", host, port)
}
var proto string
if strings.ContainsRune(address, '/') {
proto = "unix"
} else {
proto = "tcp"
}
return &RedisClientFactory{
proto: proto,
address: address,
}, nil
}
func (rcf *RedisClientFactory) Close() error {
return nil
}
func (rcf *RedisClientFactory) Create() (DBClient, error) {
return NewRedisClient(rcf.proto, rcf.address)
}
type RedisClient struct {
conn net.Conn
reader *bufio.Reader