diff --git a/internal/mocks/bufiox/bufreader.go b/internal/mocks/bufiox/bufreader.go new file mode 100644 index 0000000000..a83787c51f --- /dev/null +++ b/internal/mocks/bufiox/bufreader.go @@ -0,0 +1,137 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by MockGen. DO NOT EDIT. +// Source: ../../../gopkg/bufiox/bufreader.go + +// Package bufiox is a generated GoMock package. +package bufiox + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockReader is a mock of Reader interface. +type MockReader struct { + ctrl *gomock.Controller + recorder *MockReaderMockRecorder +} + +// MockReaderMockRecorder is the mock recorder for MockReader. +type MockReaderMockRecorder struct { + mock *MockReader +} + +// NewMockReader creates a new mock instance. +func NewMockReader(ctrl *gomock.Controller) *MockReader { + mock := &MockReader{ctrl: ctrl} + mock.recorder = &MockReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReader) EXPECT() *MockReaderMockRecorder { + return m.recorder +} + +// Next mocks base method. +func (m *MockReader) Next(n int) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next", n) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Next indicates an expected call of Next. +func (mr *MockReaderMockRecorder) Next(n interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockReader)(nil).Next), n) +} + +// Peek mocks base method. +func (m *MockReader) Peek(n int) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Peek", n) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Peek indicates an expected call of Peek. +func (mr *MockReaderMockRecorder) Peek(n interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Peek", reflect.TypeOf((*MockReader)(nil).Peek), n) +} + +// ReadBinary mocks base method. +func (m *MockReader) ReadBinary(bs []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadBinary", bs) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadBinary indicates an expected call of ReadBinary. +func (mr *MockReaderMockRecorder) ReadBinary(bs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadBinary", reflect.TypeOf((*MockReader)(nil).ReadBinary), bs) +} + +// ReadLen mocks base method. +func (m *MockReader) ReadLen() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadLen") + ret0, _ := ret[0].(int) + return ret0 +} + +// ReadLen indicates an expected call of ReadLen. +func (mr *MockReaderMockRecorder) ReadLen() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadLen", reflect.TypeOf((*MockReader)(nil).ReadLen)) +} + +// Release mocks base method. +func (m *MockReader) Release(e error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Release", e) + ret0, _ := ret[0].(error) + return ret0 +} + +// Release indicates an expected call of Release. +func (mr *MockReaderMockRecorder) Release(e interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockReader)(nil).Release), e) +} + +// Skip mocks base method. +func (m *MockReader) Skip(n int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Skip", n) + ret0, _ := ret[0].(error) + return ret0 +} + +// Skip indicates an expected call of Skip. +func (mr *MockReaderMockRecorder) Skip(n interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Skip", reflect.TypeOf((*MockReader)(nil).Skip), n) +} diff --git a/internal/mocks/conn.go b/internal/mocks/conn.go index f4d68713f4..3a576ea08b 100644 --- a/internal/mocks/conn.go +++ b/internal/mocks/conn.go @@ -20,6 +20,8 @@ import ( bytes2 "bytes" "net" "time" + + "github.com/cloudwego/gopkg/bufiox" ) var _ net.Conn = &Conn{} @@ -111,3 +113,12 @@ func NewIOConn() *Conn { }, } } + +type MockConnWithBufioxReader struct { + net.Conn + BufioxReader bufiox.Reader +} + +func (c *MockConnWithBufioxReader) Reader() bufiox.Reader { + return c.BufioxReader +} diff --git a/internal/mocks/update.sh b/internal/mocks/update.sh index 350fb4ad79..b51d537825 100755 --- a/internal/mocks/update.sh +++ b/internal/mocks/update.sh @@ -23,6 +23,7 @@ files=( ../../pkg/proxy/proxy.go proxy/proxy.go proxy ../../pkg/utils/sharedticker.go utils/sharedticker.go utils ../../../netpoll/connection.go netpoll/connection.go netpoll +../../../gopkg/bufiox/bufreader.go bufiox/bufreader.go bufiox $GOROOT/src/net/net.go net/net.go net ) diff --git a/pkg/remote/trans/nphttp2/grpc/framer.go b/pkg/remote/trans/nphttp2/grpc/framer.go index c27a74cefb..ebe221f161 100644 --- a/pkg/remote/trans/nphttp2/grpc/framer.go +++ b/pkg/remote/trans/nphttp2/grpc/framer.go @@ -20,6 +20,8 @@ import ( "io" "net" + "github.com/cloudwego/gopkg/bufiox" + "github.com/bytedance/gopkg/lang/dirtmake" "github.com/cloudwego/netpoll" "golang.org/x/net/http2/hpack" @@ -29,16 +31,24 @@ import ( type framer struct { *grpcframe.Framer - reader netpoll.Reader + reader bufiox.Reader writer *bufWriter } func newFramer(conn net.Conn, writeBufferSize, readBufferSize, maxHeaderListSize uint32) *framer { - var r netpoll.Reader - if npConn, ok := conn.(interface{ Reader() netpoll.Reader }); ok { - r = npConn.Reader() + var r bufiox.Reader + // Initialize a bufiox.Reader based on the connection: + // 1. If the connection's reader is a `bufiox.Reader`, use it directly. + // 2. If the connection's reader is a `netpoll.Reader`, wrap it in a `netpollBufioxReader`. + // 3. Otherwise, create a `bufiox.DefaultReader` with the connection. + if bc, ok := conn.(interface{ Reader() bufiox.Reader }); ok { + r = bc.Reader() } else { - r = netpoll.NewReader(conn) + if npConn, ok := conn.(interface{ Reader() netpoll.Reader }); ok { + r = &netpollBufioxReader{Reader: npConn.Reader()} + } else { + r = bufiox.NewDefaultReader(conn) + } } w := newBufWriter(conn, int(writeBufferSize)) fr := &framer{ @@ -101,3 +111,47 @@ func (w *bufWriter) Flush() error { w.offset = 0 return w.err } + +var _ bufiox.Reader = &netpollBufioxReader{} + +// netpollBufioxReader implements bufiox.Reader with netpoll.Reader +type netpollBufioxReader struct { + netpoll.Reader + readLen int +} + +func (r *netpollBufioxReader) Next(n int) (p []byte, err error) { + p, err = r.Reader.Next(n) + if err != nil { + return nil, err + } + r.readLen += len(p) + return p, nil +} + +func (r *netpollBufioxReader) ReadBinary(bs []byte) (n int, err error) { + p, err := r.Next(len(bs)) + if err != nil { + return 0, err + } + n = copy(bs, p) + return n, nil +} + +func (r *netpollBufioxReader) Skip(n int) (err error) { + err = r.Reader.Skip(n) + if err != nil { + return err + } + r.readLen += n + return nil +} + +func (r *netpollBufioxReader) ReadLen() (n int) { + return r.readLen +} + +func (r *netpollBufioxReader) Release(e error) (err error) { + r.readLen = 0 + return r.Reader.Release() +} diff --git a/pkg/remote/trans/nphttp2/grpc/framer_test.go b/pkg/remote/trans/nphttp2/grpc/framer_test.go new file mode 100644 index 0000000000..8afa2f44b8 --- /dev/null +++ b/pkg/remote/trans/nphttp2/grpc/framer_test.go @@ -0,0 +1,114 @@ +/* + * Copyright 2025 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package grpc + +import ( + "io" + "net" + "testing" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/golang/mock/gomock" + + mocksnetpoll "github.com/cloudwego/kitex/internal/mocks/netpoll" + "github.com/cloudwego/kitex/internal/test" +) + +type mockConnWithBufioxReader struct { + net.Conn + bufioxReader bufiox.Reader +} + +func (c *mockConnWithBufioxReader) Reader() bufiox.Reader { + return c.bufioxReader +} + +func TestNewFramer(t *testing.T) { + // conn without bufiox reader + var conn net.Conn + conn = &mockConn{} + fr := newFramer(conn, 0, 0, 0) + _, ok := fr.reader.(*bufiox.DefaultReader) + test.Assert(t, ok) + + // conn with bufiox reader + reader := &bufiox.DefaultReader{} + conn = &mockConnWithBufioxReader{bufioxReader: reader} + fr = newFramer(conn, 0, 0, 0) + test.Assert(t, fr.reader == reader) + + // netpoll conn + conn = &mockNetpollConn{} + fr = newFramer(conn, 0, 0, 0) + _, ok = fr.reader.(*netpollBufioxReader) + test.Assert(t, ok) +} + +func TestNetpollBufioxReader(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var ( + validReadLength = 5 + eofReadLength = 8 + validSkipLength = 5 + eofSkipLength = 3 + expectedString = "Hello" + ) + + mockReader := mocksnetpoll.NewMockReader(ctrl) + mockReader.EXPECT().Next(validReadLength).Return([]byte(expectedString), nil).Times(2) // readBinary also calls reader.Next + mockReader.EXPECT().Next(eofReadLength).Return(nil, io.EOF).Times(1) + mockReader.EXPECT().Skip(validSkipLength).Return(nil).Times(1) + mockReader.EXPECT().Skip(eofSkipLength).Return(io.EOF).Times(1) + mockReader.EXPECT().Release().Return(nil).Times(1) + reader := &netpollBufioxReader{ + Reader: mockReader, + } + currReadLen := 0 + + // Next + p, err := reader.Next(validReadLength) + test.Assert(t, err == nil) + test.Assert(t, "Hello" == string(p)) + test.Assert(t, validReadLength == reader.ReadLen()) + currReadLen = validReadLength + p, err = reader.Next(8) + test.Assert(t, io.EOF == err) + test.Assert(t, p == nil) + + // ReadBinary + buf := make([]byte, validReadLength) + n, err := reader.ReadBinary(buf) + test.Assert(t, err == nil) + test.Assert(t, validReadLength == n) + test.Assert(t, "Hello" == string(buf)) + test.Assert(t, currReadLen+validReadLength == reader.ReadLen()) + currReadLen = reader.ReadLen() + + // Skip + err = reader.Skip(validReadLength) + test.Assert(t, err == nil) + test.Assert(t, currReadLen+validReadLength == reader.ReadLen()) + err = reader.Skip(3) + test.Assert(t, io.EOF == err) + + // Release + err = reader.Release(nil) + test.Assert(t, err == nil) + test.Assert(t, reader.ReadLen() == 0) +} diff --git a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go index 603bd81380..586bba6b76 100644 --- a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go +++ b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_parser.go @@ -16,7 +16,8 @@ import ( "encoding/binary" "fmt" - "github.com/cloudwego/netpoll" + "github.com/cloudwego/gopkg/bufiox" + "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) @@ -76,7 +77,7 @@ func (f *DataFrame) Data() []byte { return f.data } -func parseDataFrame(fc *frameCache, fh http2.FrameHeader, payload netpoll.Reader) (http2.Frame, error) { +func parseDataFrame(fc *frameCache, fh http2.FrameHeader, payload bufiox.Reader) (http2.Frame, error) { if fh.StreamID == 0 { // DATA frames MUST be associated with a stream. If a // DATA frame is received whose stream identifier @@ -91,12 +92,12 @@ func parseDataFrame(fc *frameCache, fh http2.FrameHeader, payload netpoll.Reader var padSize byte payloadLen := int(fh.Length) if fh.Flags.Has(http2.FlagDataPadded) { - var err error - padSize, err = payload.ReadByte() - payloadLen-- + p, err := payload.Next(1) if err != nil { return nil, err } + padSize = p[0] + payloadLen-- } if int(padSize) > payloadLen { // If the length of the padding is greater than the diff --git a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader.go b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader.go index 551a944103..89143124f1 100644 --- a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader.go +++ b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader.go @@ -19,7 +19,7 @@ import ( "io" "strings" - "github.com/cloudwego/netpoll" + "github.com/cloudwego/gopkg/bufiox" "golang.org/x/net/http/httpguts" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -34,7 +34,7 @@ type Framer struct { lastHeaderStream uint32 lastFrame http2.Frame - reader netpoll.Reader + reader bufiox.Reader maxReadSize uint32 writer io.Writer @@ -104,7 +104,7 @@ func (fc *frameCache) getDataFrame() *DataFrame { } // NewFramer returns a Framer that writes frames to w and reads them from r. -func NewFramer(w io.Writer, r netpoll.Reader) *Framer { +func NewFramer(w io.Writer, r bufiox.Reader) *Framer { fr := &Framer{ writer: w, reader: r, diff --git a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader_test.go b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader_test.go index aaa3708488..9b7d377f98 100644 --- a/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader_test.go +++ b/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader_test.go @@ -21,15 +21,16 @@ import ( "strings" "testing" + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/thrift" - "github.com/cloudwego/netpoll" "golang.org/x/net/http2" "github.com/cloudwego/kitex/internal/test" ) type mockNetpollReader struct { - netpoll.Reader + bufiox.Reader buf []byte } @@ -65,7 +66,7 @@ func Test_Framer_readAndCheckFrameHeader(t *testing.T) { fr.SetMaxReadFrameSize(http2MaxFrameLen) testcases := []struct { desc string - reader netpoll.Reader + reader bufiox.Reader checkFrameHeaderAndErr func(*testing.T, http2.FrameHeader, error) }{ { diff --git a/pkg/remote/trans/nphttp2/grpc/http2_client.go b/pkg/remote/trans/nphttp2/grpc/http2_client.go index cff486f825..c493339ae9 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_client.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_client.go @@ -1204,7 +1204,7 @@ func (t *http2Client) reader() { default: klog.Warnf("transport: http2Client.reader got unhandled frame type %v.", frame) } - t.framer.reader.Release() + t.framer.reader.Release(nil) } } diff --git a/pkg/remote/trans/nphttp2/grpc/http2_server.go b/pkg/remote/trans/nphttp2/grpc/http2_server.go index 543e381d7a..555184878f 100644 --- a/pkg/remote/trans/nphttp2/grpc/http2_server.go +++ b/pkg/remote/trans/nphttp2/grpc/http2_server.go @@ -469,7 +469,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context. default: klog.CtxErrorf(t.ctx, "transport: http2Server.HandleStreams found unhandled frame type %v.", frame) } - t.framer.reader.Release() + t.framer.reader.Release(nil) } } diff --git a/pkg/remote/trans/nphttp2/grpc/mocks_test.go b/pkg/remote/trans/nphttp2/grpc/mocks_test.go index 1b52ad8baa..efa50ed86b 100644 --- a/pkg/remote/trans/nphttp2/grpc/mocks_test.go +++ b/pkg/remote/trans/nphttp2/grpc/mocks_test.go @@ -45,10 +45,11 @@ var _ netpoll.Connection = &mockNetpollConn{} // mockNetpollConn implements netpoll.Connection. type mockNetpollConn struct { mockConn + reader netpoll.Reader } func (m *mockNetpollConn) Reader() netpoll.Reader { - panic("implement me") + return m.reader } func (m *mockNetpollConn) Writer() netpoll.Writer { diff --git a/pkg/remote/trans/nphttp2/grpc/transport_test.go b/pkg/remote/trans/nphttp2/grpc/transport_test.go index 78ec8e2161..bac239e82e 100644 --- a/pkg/remote/trans/nphttp2/grpc/transport_test.go +++ b/pkg/remote/trans/nphttp2/grpc/transport_test.go @@ -1286,7 +1286,7 @@ func TestServerWithMisbehavedClient(t *testing.T) { // success chan indicates that reader received a RSTStream from server. success := make(chan struct{}) var mu sync.Mutex - framer := grpcframe.NewFramer(mconn, mconn.(netpoll.Connection).Reader()) + framer := grpcframe.NewFramer(mconn, &netpollBufioxReader{Reader: mconn.(netpoll.Connection).Reader()}) if err := framer.WriteSettings(); err != nil { t.Fatalf("Error while writing settings: %v", err) } diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index c1293e3a6e..9735450fe9 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -29,6 +29,7 @@ import ( "sync/atomic" "time" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/pkg/endpoint" @@ -92,17 +93,26 @@ var prefaceReadAtMost = func() int { func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) error { // Check the validity of client preface. - // FIXME: should not rely on netpoll.Reader + var peekReader interface { + Peek(n int) (buf []byte, err error) + } + if withReader, ok := conn.(interface{ Reader() bufiox.Reader }); ok { + if br := withReader.Reader(); br != nil { + peekReader = br + } + } if withReader, ok := conn.(interface{ Reader() netpoll.Reader }); ok { - if npReader := withReader.Reader(); npReader != nil { - // read at most avoid block - preface, err := npReader.Peek(prefaceReadAtMost) - if err != nil { - return err - } - if len(preface) >= prefaceReadAtMost && bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { - return nil - } + if br := withReader.Reader(); br != nil { + peekReader = br + } + } + if peekReader != nil { + preface, err := peekReader.Peek(prefaceReadAtMost) + if err != nil { + return err + } + if len(preface) >= prefaceReadAtMost && bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { + return nil } } return errors.New("error protocol not match") diff --git a/pkg/remote/trans/nphttp2/server_handler_test.go b/pkg/remote/trans/nphttp2/server_handler_test.go index 0cb9e60236..f4a4aab46c 100644 --- a/pkg/remote/trans/nphttp2/server_handler_test.go +++ b/pkg/remote/trans/nphttp2/server_handler_test.go @@ -23,11 +23,10 @@ import ( "testing" "time" - "github.com/cloudwego/kitex/internal/mocks" - "github.com/golang/mock/gomock" - "github.com/cloudwego/kitex/internal/mocks/netpoll" + "github.com/cloudwego/kitex/internal/mocks" + mockBufiox "github.com/cloudwego/kitex/internal/mocks/bufiox" mocksremote "github.com/cloudwego/kitex/internal/mocks/remote" "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/kerrors" @@ -315,24 +314,21 @@ func TestSvrTransHandlerProtocolMatch(t *testing.T) { defer ctrl.Finish() th := &svrTransHandler{} - // netpoll reader + // bufiox reader // 1. success - reader := netpoll.NewMockReader(ctrl) + reader := mockBufiox.NewMockReader(ctrl) reader.EXPECT().Peek(prefaceReadAtMost).Times(1).Return(grpcTransport.ClientPreface, nil) - conn := netpoll.NewMockConnection(ctrl) - conn.EXPECT().Reader().AnyTimes().Return(reader) + conn := &mocks.MockConnWithBufioxReader{BufioxReader: reader} err := th.ProtocolMatch(context.Background(), conn) test.Assert(t, err == nil, err) // 2. failed, no reader - conn = netpoll.NewMockConnection(ctrl) - conn.EXPECT().Reader().AnyTimes().Return(nil) + conn = &mocks.MockConnWithBufioxReader{} err = th.ProtocolMatch(context.Background(), conn) test.Assert(t, err != nil, err) // 3. failed, wrong preface - failedReader := netpoll.NewMockReader(ctrl) + failedReader := mockBufiox.NewMockReader(ctrl) failedReader.EXPECT().Peek(prefaceReadAtMost).Times(1).Return([]byte{}, nil) - conn = netpoll.NewMockConnection(ctrl) - conn.EXPECT().Reader().AnyTimes().Return(failedReader) + conn = &mocks.MockConnWithBufioxReader{BufioxReader: failedReader} err = th.ProtocolMatch(context.Background(), conn) test.Assert(t, err != nil, err)