diff --git a/cmd/mtredisalize/connection.go b/cmd/mtredisalize/connection.go index b834030..e900abb 100644 --- a/cmd/mtredisalize/connection.go +++ b/cmd/mtredisalize/connection.go @@ -21,16 +21,18 @@ var ( ) type Connection struct { - conn net.Conn - session Session - boolArray []bool + conn net.Conn + session Session + maxBulkStringSize int64 + boolArray []bool } -func NewConnection(conn net.Conn, session Session) *Connection { +func NewConnection(conn net.Conn, session Session, maxBulkStringSize int64) *Connection { return &Connection{ - conn: conn, - session: session, - boolArray: []bool{}} + conn: conn, + session: session, + maxBulkStringSize: maxBulkStringSize, + boolArray: []bool{}} } func (c *Connection) Run() { @@ -40,7 +42,7 @@ func (c *Connection) Run() { }() rce := NewRedisCommandExecutor(c) r := bufio.NewReaderSize(c.conn, 8*1024) - parser := NewRedisParser(r, rce) + parser := NewRedisParser(r, rce, c.maxBulkStringSize) parser.Parse() log.Println("client disconnected") } diff --git a/cmd/mtredisalize/main.go b/cmd/mtredisalize/main.go index 7a3c2f6..3a612ac 100644 --- a/cmd/mtredisalize/main.go +++ b/cmd/mtredisalize/main.go @@ -16,9 +16,10 @@ import ( ) const ( - Version = "0.3" - GCDuration = "24h" - ChangeDuration = "30s" + defaultMaxBulkStringSize = 32 * 1024 * 1024 + Version = "0.3" + GCDuration = "24h" + ChangeDuration = "30s" ) func usage() { @@ -31,15 +32,16 @@ func usage() { func main() { var ( - port int - host string - driver string - cacheSize int - version bool - interleaved bool - changeUrl string - gcDuration string - changeDuration string + port int + host string + driver string + cacheSize int + version bool + interleaved bool + changeUrl string + gcDuration string + changeDuration string + maxBulkStringSize int64 ) flag.Usage = usage @@ -56,6 +58,8 @@ func main() { flag.StringVar(&changeDuration, "change-duration", ChangeDuration, "Duration to aggregate changes.") flag.StringVar(&changeUrl, "change-url", "", "URL to send changes to.") + flag.Int64Var(&maxBulkStringSize, "max-bulk-string-size", defaultMaxBulkStringSize, + "max size of a bulk string to be accepted as input (in bytes).") flag.Parse() if version { @@ -147,7 +151,7 @@ func main() { log.Printf("Cannot create session: %s", err) conn.Close() } else { - go NewConnection(conn, session).Run() + go NewConnection(conn, session, maxBulkStringSize).Run() } case <-sigChan: log.Println("Shutting down") diff --git a/cmd/mtredisalize/parser.go b/cmd/mtredisalize/parser.go index bd38351..432c6e1 100644 --- a/cmd/mtredisalize/parser.go +++ b/cmd/mtredisalize/parser.go @@ -13,8 +13,6 @@ import ( "strconv" ) -const maxBulkStringSize = 8 * 1024 * 1024 - type RedisConsumer interface { ConsumeInteger(int64) bool ConsumeArray(int64) bool @@ -24,14 +22,17 @@ type RedisConsumer interface { } type RedisParser struct { - reader *bufio.Reader - consumer RedisConsumer + reader *bufio.Reader + consumer RedisConsumer + maxBulkStringSize int64 } -func NewRedisParser(reader *bufio.Reader, consumer RedisConsumer) *RedisParser { +func NewRedisParser(reader *bufio.Reader, consumer RedisConsumer, + maxBulkStringSize int64) *RedisParser { return &RedisParser{ - reader: reader, - consumer: consumer} + reader: reader, + consumer: consumer, + maxBulkStringSize: maxBulkStringSize} } func (rp *RedisParser) Parse() { @@ -94,7 +95,7 @@ func (rp *RedisParser) bulkString(line []byte) bool { case i == 0: return rp.consumer.ConsumeBulkString([]byte{}) default: - if i > maxBulkStringSize { // prevent denial of service. + if i > rp.maxBulkStringSize { // prevent denial of service. return rp.consumer.ConsumeError( fmt.Errorf("Bulk string too large (%d bytes).\n", i)) }