diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 72342ee7..1f09dd6e 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -28,6 +28,7 @@ import ( "crypto/tls" "io" "net" + "sync" "time" "github.com/pingcap/tiproxy/lib/config" @@ -42,6 +43,19 @@ var ( ErrInvalidSequence = errors.New("invalid sequence") ) +var ( + readerPool = sync.Pool{ + New: func() any { + return bufio.NewReaderSize(nil, DefaultConnBufferSize) + }, + } + writerPool = sync.Pool{ + New: func() any { + return bufio.NewWriterSize(nil, DefaultConnBufferSize) + }, + } +) + const ( DefaultConnBufferSize = 32 * 1024 ) @@ -86,16 +100,27 @@ type basicReadWriter struct { inBytes uint64 outBytes uint64 sequence uint8 + pooled bool } func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter { if bufferSize == 0 { bufferSize = DefaultConnBufferSize } - return &basicReadWriter{ - Conn: conn, - ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(conn, bufferSize), bufio.NewWriterSize(conn, bufferSize)), + brw := &basicReadWriter{ + Conn: conn, + } + if bufferSize == DefaultConnBufferSize { + r := readerPool.Get().(*bufio.Reader) + r.Reset(conn) + w := writerPool.Get().(*bufio.Writer) + w.Reset(conn) + brw.ReadWriter = bufio.NewReadWriter(r, w) + brw.pooled = true + } else { + brw.ReadWriter = bufio.NewReadWriter(bufio.NewReaderSize(conn, bufferSize), bufio.NewWriterSize(conn, bufferSize)) } + return brw } func (brw *basicReadWriter) Read(b []byte) (n int, err error) { @@ -153,6 +178,22 @@ func (brw *basicReadWriter) ResetSequence() { brw.sequence = 0 } +func (brw *basicReadWriter) Free() { + if brw.pooled { + brw.pooled = false + brw.ReadWriter.Reader.Reset(nil) + brw.ReadWriter.Writer.Reset(nil) + readerPool.Put(brw.ReadWriter.Reader) + writerPool.Put(brw.ReadWriter.Writer) + } +} + +func (brw *basicReadWriter) Close() error { + err := brw.Conn.Close() + brw.Free() + return err +} + func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState { return tls.ConnectionState{} } @@ -496,6 +537,7 @@ func (p *packetIO) Close() error { errs = append(errs, err) } */ + if err := p.readWriter.Close(); err != nil && !errors.Is(err, net.ErrClosed) { errs = append(errs, errors.WithStack(err)) } diff --git a/pkg/proxy/net/packetio_test.go b/pkg/proxy/net/packetio_test.go index 15110884..9f59b7fb 100644 --- a/pkg/proxy/net/packetio_test.go +++ b/pkg/proxy/net/packetio_test.go @@ -719,3 +719,55 @@ func runForwardBenchmark(b *testing.B, f func(packetIO1, packetIO2 *packetIO)) { _ = packetIO2.Close() wg.Wait() } + +func TestPacketIOPooling(t *testing.T) { + testTCPConn(t, + func(t *testing.T, cli *packetIO) { + brw, ok := cli.readWriter.(*basicReadWriter) + require.True(t, ok) + require.True(t, brw.pooled, "pooled flag should be true for default buffer size") + + require.NoError(t, cli.WritePacket([]byte("pooltest"), true)) + }, + func(t *testing.T, srv *packetIO) { + brw, ok := srv.readWriter.(*basicReadWriter) + require.True(t, ok) + require.True(t, brw.pooled, "pooled flag should be true for default buffer size") + + data, err := srv.ReadPacket() + require.NoError(t, err) + require.Equal(t, []byte("pooltest"), data) + }, + 1, + ) + + lg, _ := logger.CreateLoggerForTest(t) + cli, srv := net.Pipe() + cliIO := NewPacketIO(cli, lg, DefaultConnBufferSize*2) + srvIO := NewPacketIO(srv, lg, DefaultConnBufferSize*2) + brw, ok := cliIO.readWriter.(*basicReadWriter) + require.True(t, ok) + require.False(t, brw.pooled, "pooled flag should be false for non-default buffer size") + _ = cliIO.Close() + _ = srvIO.Close() + + testTCPConn(t, + func(t *testing.T, cli *packetIO) { + require.NoError(t, cli.Close()) + require.NoError(t, cli.Close()) + }, + func(t *testing.T, srv *packetIO) { + require.NoError(t, srv.Close()) + require.NoError(t, srv.Close()) + }, + 1, + ) + + for i := 0; i < 100; i++ { + c1, c2 := net.Pipe() + p1 := NewPacketIO(c1, lg, DefaultConnBufferSize) + p2 := NewPacketIO(c2, lg, DefaultConnBufferSize) + _ = p1.Close() + _ = p2.Close() + } +}