diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b432e3a2f7..d9d5e8708a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -19,7 +19,7 @@ jobs: run: | cd .. rm -rf kitex-tests - git clone --depth=1 https://github.com/cloudwego/kitex-tests.git + git clone -b test/ttstream_pb --depth=1 https://github.com/DMwangnima/kitex-tests.git cd kitex-tests KITEX_TOOL_USE_PROTOC=0 ./run.sh ${{github.workspace}} cd ${{github.workspace}} diff --git a/internal/codec/protobuf_struct.go b/internal/codec/protobuf_struct.go new file mode 100644 index 0000000000..0ad8ad9a8a --- /dev/null +++ b/internal/codec/protobuf_struct.go @@ -0,0 +1,157 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package codec + +import ( + "context" + "errors" + "fmt" + + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/fastpb" + "google.golang.org/protobuf/proto" + + "github.com/cloudwego/kitex/internal/utils/safemcache" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +var ( + errEncodePbGenericEmptyMethod = errors.New("empty methodName in generic pb Encode") + errDecodePbGenericEmptyMethod = errors.New("empty methodName in generic pb Decode") +) + +// gogoproto generate +type marshaler interface { + MarshalTo(data []byte) (n int, err error) + Size() int +} + +type protobufV2MsgCodec interface { + XXX_Unmarshal(b []byte) error + XXX_Marshal(b []byte, deterministic bool) ([]byte, error) +} + +type EncodeResult struct { + Payload []byte // encoded byte slice + PreAllocate bool // the struct encoded can pre-allocate memory + PreAllocateSize int // pre-allocate size for the struct encoded +} + +const GRPCDataFrameHeaderLen = 5 + +func GRPCEncodeProtobufStruct(ctx context.Context, ri rpcinfo.RPCInfo, msg any, isCompress bool) (EncodeResult, error) { + prefixLen := GRPCDataFrameHeaderLen + if isCompress { + prefixLen = 0 + } + return encodeProtobufStruct(ctx, ri, msg, safemcache.Malloc, safemcache.Free, prefixLen) +} + +func TTStreamEncodeProtobufStruct(ctx context.Context, ri rpcinfo.RPCInfo, msg any) (EncodeResult, error) { + return encodeProtobufStruct(ctx, ri, msg, mcacheMalloc, mcache.Free, 0) +} + +func mcacheMalloc(size int) []byte { + return mcache.Malloc(size) +} + +func encodeProtobufStruct(ctx context.Context, ri rpcinfo.RPCInfo, msg any, + mallocFunc func(int) []byte, freeFunc func([]byte), prefixLen int, +) (res EncodeResult, err error) { + var payload []byte + switch t := msg.(type) { + // Deprecated: fastpb is no longer used + case fastpb.Writer: + res.PreAllocate = true + res.PreAllocateSize = t.Size() + payload = mallocFunc(res.PreAllocateSize + prefixLen) + t.FastWrite(payload[prefixLen:]) + case marshaler: + size := t.Size() + payload = mallocFunc(size + prefixLen) + if _, err = t.MarshalTo(payload[prefixLen:]); err != nil { + freeFunc(payload) + return res, err + } + res.PreAllocate = true + res.PreAllocateSize = size + case protobufV2MsgCodec: + payload, err = t.XXX_Marshal(nil, true) + case proto.Message: + payload, err = proto.Marshal(t) + case protobuf.ProtobufMsgCodec: + payload, err = t.Marshal(nil) + case protobuf.MessageWriterWithContext: + payload, err = encodeProtobufGeneric(ctx, ri, t) + default: + err = fmt.Errorf("invalid payload %T in EncodeProtobufStruct", t) + } + + if err != nil { + return res, err + } + res.Payload = payload + return res, nil +} + +func encodeProtobufGeneric(ctx context.Context, ri rpcinfo.RPCInfo, w protobuf.MessageWriterWithContext) ([]byte, error) { + methodName := ri.Invocation().MethodName() + if methodName == "" { + return nil, errEncodePbGenericEmptyMethod + } + actualMsg, err := w.WritePb(ctx, methodName) + if err != nil { + return nil, err + } + payload, ok := actualMsg.([]byte) + if !ok { + return nil, fmt.Errorf("encodePbGeneric failed, got %T", actualMsg) + } + return payload, nil +} + +func DecodeProtobufStruct(ctx context.Context, ri rpcinfo.RPCInfo, payload []byte, msg any) (err error) { + // Deprecated: fastpb is no longer used + if t, ok := msg.(fastpb.Reader); ok { + if len(payload) == 0 { + // if all fields of a struct is default value, data will be nil + // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. + // So, when data is nil, use default protobuf unmarshal method to decode the struct. + // todo: fix fastpb + } else { + _, err = fastpb.ReadMessage(payload, fastpb.SkipTypeCheck, t) + return err + } + } + switch t := msg.(type) { + case protobufV2MsgCodec: + return t.XXX_Unmarshal(payload) + case proto.Message: + return proto.Unmarshal(payload, t) + case protobuf.ProtobufMsgCodec: + return t.Unmarshal(payload) + case protobuf.MessageReaderWithMethodWithContext: + methodName := ri.Invocation().MethodName() + if methodName == "" { + return errDecodePbGenericEmptyMethod + } + return t.ReadPb(ctx, methodName, payload) + default: + return fmt.Errorf("invalid payload %T in DecodeProtobufStruct", t) + } +} diff --git a/internal/codec/protobuf_struct_test.go b/internal/codec/protobuf_struct_test.go new file mode 100644 index 0000000000..63c12bfb28 --- /dev/null +++ b/internal/codec/protobuf_struct_test.go @@ -0,0 +1,595 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package codec + +import ( + "context" + "encoding/binary" + "errors" + "math/bits" + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +// Mock marshaler (gogoproto style) for testing +type mockMarshaler struct { + data string + fail bool +} + +func (m *mockMarshaler) Size() int { + return len(m.data) +} + +func (m *mockMarshaler) MarshalTo(data []byte) (int, error) { + if m.fail { + return 0, errors.New("marshal failed") + } + copy(data, m.data) + return len(m.data), nil +} + +// Mock protobufV2MsgCodec for testing +type mockProtobufV2Msg struct { + data string + fail bool +} + +func (m *mockProtobufV2Msg) XXX_Unmarshal(b []byte) error { + if m.fail { + return errors.New("unmarshal failed") + } + m.data = string(b) + return nil +} + +func (m *mockProtobufV2Msg) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if m.fail { + return nil, errors.New("marshal failed") + } + return []byte(m.data), nil +} + +// Mock MessageWriterWithContext for generic protobuf testing +type mockMessageWriter struct { + data []byte + fail bool + method string +} + +func (m *mockMessageWriter) WritePb(ctx context.Context, method string) (interface{}, error) { + m.method = method + if m.fail { + return nil, errors.New("write failed") + } + return m.data, nil +} + +// Mock MessageReaderWithMethodWithContext for generic protobuf testing +type mockMessageReader struct { + data []byte + fail bool + method string +} + +func (m *mockMessageReader) ReadPb(ctx context.Context, method string, payload []byte) error { + m.method = method + if m.fail { + return errors.New("read failed") + } + m.data = payload + return nil +} + +// sizeVarint returns the encoded size of a varint. +func sizeVarint(v uint64) int { + return int(9*uint32(bits.Len64(v))+64) / 64 +} + +// Mock fastpb.Writer (deprecated but still needs testing) +type mockFastWriter struct { + num int32 + v string +} + +func (p *mockFastWriter) Size() int { + n := sizeVarint(uint64(p.num)<<3 | uint64(2)) + n += sizeVarint(uint64(len(p.v))) + n += len(p.v) + return n +} + +func (p *mockFastWriter) FastWrite(in []byte) int { + n := binary.PutUvarint(in, uint64(p.num)<<3|uint64(2)) + n += binary.PutUvarint(in[n:], uint64(len(p.v))) + n += copy(in[n:], p.v) + return n +} + +// Mock fastpb.Reader +type mockFastReader struct { + num int32 + v string +} + +func (p *mockFastReader) FastRead(buf []byte, t int8, number int32) (int, error) { + if t != 2 { + return 0, errors.New("invalid type") + } + p.num = number + sz, n := binary.Uvarint(buf) + buf = buf[n:] + p.v = string(buf[:sz]) + return int(sz) + n, nil +} + +func newMockRPCInfo(method string) rpcinfo.RPCInfo { + ri := rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo("", method, nil, nil), + rpcinfo.NewEndpointInfo("", method, nil, nil), + rpcinfo.NewInvocation("", method), + rpcinfo.NewRPCConfig(), + rpcinfo.NewRPCStats(), + ) + return ri +} + +func TestGRPCEncodeProtobufStruct(t *testing.T) { + t.Run("fastpb.Writer", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockFastWriter{ + num: 100, + v: "fastpb test", + } + + // Test without compression (with GRPC header space) + res, err := GRPCEncodeProtobufStruct(ctx, ri, msg, false) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil, res) + test.Assert(t, len(res.Payload) == GRPCDataFrameHeaderLen+msg.Size(), res) + test.Assert(t, res.PreAllocate, res) + test.Assert(t, res.PreAllocateSize == msg.Size(), res) + + // Test with compression (no header space) + res2, err := GRPCEncodeProtobufStruct(ctx, ri, msg, true) + test.Assert(t, err == nil, err) + test.Assert(t, res2.Payload != nil, res2) + test.Assert(t, len(res2.Payload) == msg.Size(), res2) + test.Assert(t, res2.PreAllocate, res2) + test.Assert(t, res2.PreAllocateSize == msg.Size(), res2) + }) + t.Run("marshaler", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockMarshaler{data: "test data"} + + // Test without compression + res, err := GRPCEncodeProtobufStruct(ctx, ri, msg, false) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, len(res.Payload) >= GRPCDataFrameHeaderLen, res) + test.Assert(t, res.PreAllocate) + test.Assert(t, res.PreAllocateSize == len(msg.data)) + test.Assert(t, string(res.Payload[GRPCDataFrameHeaderLen:]) == msg.data) + + // Test with compression + res2, err := GRPCEncodeProtobufStruct(ctx, ri, msg, true) + test.Assert(t, err == nil, err) + test.Assert(t, res2.Payload != nil) + test.Assert(t, len(res2.Payload) == len(msg.data), res2) + test.Assert(t, res2.PreAllocate) + test.Assert(t, res2.PreAllocateSize == len(msg.data)) + test.Assert(t, string(res2.Payload) == msg.data) + + // Test marshal failure + msgFail := &mockMarshaler{data: "test", fail: true} + resFail, errFail := GRPCEncodeProtobufStruct(ctx, ri, msgFail, false) + test.Assert(t, errFail != nil) + test.Assert(t, resFail.Payload == nil) + }) + t.Run("protobufV2MsgCodec", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockProtobufV2Msg{data: "protobuf v2 data"} + + // Test without compression + res, err := GRPCEncodeProtobufStruct(ctx, ri, msg, false) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, string(res.Payload) == msg.data) + test.Assert(t, !res.PreAllocate) + + // Test with compression + res2, err := GRPCEncodeProtobufStruct(ctx, ri, msg, true) + test.Assert(t, err == nil, err) + test.Assert(t, res2.Payload != nil) + test.Assert(t, string(res2.Payload) == msg.data) + test.Assert(t, !res2.PreAllocate) + + // Test encode failure + msgFail := &mockProtobufV2Msg{data: "test", fail: true} + resFail, errFail := GRPCEncodeProtobufStruct(ctx, ri, msgFail, false) + test.Assert(t, errFail != nil) + test.Assert(t, resFail.Payload == nil) + }) + t.Run("proto.Message", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + req := &protobuf.MockReq{ + Msg: "test message", + StrMap: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + StrList: []string{"item1", "item2"}, + } + + // Test without compression (with GRPC header space) + res, err := GRPCEncodeProtobufStruct(ctx, ri, req, false) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, len(res.Payload) >= GRPCDataFrameHeaderLen) + test.Assert(t, !res.PreAllocate) + + // Test with compression (no header space) + res2, err := GRPCEncodeProtobufStruct(ctx, ri, req, true) + test.Assert(t, err == nil, err) + test.Assert(t, res2.Payload != nil) + test.Assert(t, !res2.PreAllocate) + + // Decode and verify + newReq := &protobuf.MockReq{} + err = DecodeProtobufStruct(ctx, ri, res.Payload, newReq) + test.Assert(t, err == nil, err) + test.Assert(t, newReq.Msg == req.Msg) + test.DeepEqual(t, newReq.StrMap, req.StrMap) + test.DeepEqual(t, newReq.StrList, req.StrList) + }) + t.Run("generic", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + data := []byte("generic data") + msg := &mockMessageWriter{data: data} + + // Test without compression + res, err := GRPCEncodeProtobufStruct(ctx, ri, msg, false) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.DeepEqual(t, res.Payload, data) + test.Assert(t, msg.method == "TestMethod") + + // Test with compression + res2, err := GRPCEncodeProtobufStruct(ctx, ri, msg, true) + test.Assert(t, err == nil, err) + test.Assert(t, res2.Payload != nil) + test.DeepEqual(t, res2.Payload, data) + + // Test with empty method name + riEmpty := newMockRPCInfo("") + _, errEmpty := GRPCEncodeProtobufStruct(ctx, riEmpty, msg, false) + test.Assert(t, errEmpty == errEncodePbGenericEmptyMethod, errEmpty) + + // Test write failure + msgFail := &mockMessageWriter{data: data, fail: true} + _, errFail := GRPCEncodeProtobufStruct(ctx, ri, msgFail, false) + test.Assert(t, errFail != nil) + test.Assert(t, errFail.Error() == "write failed", errFail) + }) + t.Run("invalid payload", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + _, err := GRPCEncodeProtobufStruct(ctx, ri, 123, false) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "invalid payload int in EncodeProtobufStruct", err) + }) +} + +func TestTTStreamEncodeProtobufStruct(t *testing.T) { + t.Run("fastpb.Writer", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockFastWriter{ + num: 100, + v: "fastpb test", + } + + res, err := TTStreamEncodeProtobufStruct(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil, res) + test.Assert(t, len(res.Payload) == msg.Size(), res) + test.Assert(t, res.PreAllocate, res) + test.Assert(t, res.PreAllocateSize == msg.Size(), res) + }) + t.Run("marshaler", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockMarshaler{data: "ttstream data"} + + res, err := TTStreamEncodeProtobufStruct(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, len(res.Payload) == len(msg.data), res) + test.Assert(t, res.PreAllocate) + test.Assert(t, res.PreAllocateSize == len(msg.data)) + test.Assert(t, string(res.Payload) == msg.data) + + // Test marshal failure + msgFail := &mockMarshaler{data: "test", fail: true} + resFail, errFail := TTStreamEncodeProtobufStruct(ctx, ri, msgFail) + test.Assert(t, errFail != nil) + test.Assert(t, resFail.Payload == nil) + }) + t.Run("protobufV2MsgCodec", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockProtobufV2Msg{data: "protobuf v2 data"} + + res, err := TTStreamEncodeProtobufStruct(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, string(res.Payload) == msg.data) + test.Assert(t, !res.PreAllocate) + + // Test encode failure + msgFail := &mockProtobufV2Msg{data: "test", fail: true} + resFail, errFail := TTStreamEncodeProtobufStruct(ctx, ri, msgFail) + test.Assert(t, errFail != nil) + test.Assert(t, resFail.Payload == nil) + }) + t.Run("proto.Message", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + req := &protobuf.MockReq{ + Msg: "ttstream message", + StrMap: map[string]string{ + "k1": "v1", + "k2": "v2", + }, + StrList: []string{"s1", "s2"}, + } + + res, err := TTStreamEncodeProtobufStruct(ctx, ri, req) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.Assert(t, !res.PreAllocate) + + // Decode and verify + newReq := &protobuf.MockReq{} + err = DecodeProtobufStruct(ctx, ri, res.Payload, newReq) + test.Assert(t, err == nil, err) + test.Assert(t, newReq.Msg == req.Msg) + test.DeepEqual(t, newReq.StrMap, req.StrMap) + test.DeepEqual(t, newReq.StrList, req.StrList) + }) + t.Run("generic", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + data := []byte("generic ttstream data") + msg := &mockMessageWriter{data: data} + + res, err := TTStreamEncodeProtobufStruct(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, res.Payload != nil) + test.DeepEqual(t, res.Payload, data) + test.Assert(t, msg.method == "TestMethod") + + // Test with empty method name + riEmpty := newMockRPCInfo("") + _, errEmpty := TTStreamEncodeProtobufStruct(ctx, riEmpty, msg) + test.Assert(t, errEmpty == errEncodePbGenericEmptyMethod, errEmpty) + + // Test write failure + msgFail := &mockMessageWriter{data: data, fail: true} + _, errFail := TTStreamEncodeProtobufStruct(ctx, ri, msgFail) + test.Assert(t, errFail != nil) + test.Assert(t, errFail.Error() == "write failed", errFail) + }) + t.Run("invalid payload", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + _, err := TTStreamEncodeProtobufStruct(ctx, ri, "invalid string") + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "invalid payload string in EncodeProtobufStruct", err) + }) +} + +func TestDecodeProtobufStruct(t *testing.T) { + t.Run("fastpb.Reader", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + // Test with empty payload (should skip fastpb and use proto.Unmarshal fallback) + msg := &mockFastReader{} + err := DecodeProtobufStruct(ctx, ri, []byte{}, msg) + // This should not use FastRead because of empty payload + test.Assert(t, err != nil) // Will fail since mockFastReader doesn't implement proto.Message + + // Test with non-empty payload + buf := make([]byte, 20) + n := binary.PutUvarint(buf, uint64(5)<<3|uint64(2)) + n += binary.PutUvarint(buf[n:], uint64(4)) + n += copy(buf[n:], "test") + + msg2 := &mockFastReader{} + err = DecodeProtobufStruct(ctx, ri, buf[:n], msg2) + test.Assert(t, err == nil, err) + test.Assert(t, msg2.num == 5) + test.Assert(t, msg2.v == "test") + }) + t.Run("protobufV2MsgCodec", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + payload := []byte("v2 payload") + msg := &mockProtobufV2Msg{} + + err := DecodeProtobufStruct(ctx, ri, payload, msg) + test.Assert(t, err == nil, err) + test.Assert(t, msg.data == string(payload)) + + // Test decode failure + msgFail := &mockProtobufV2Msg{fail: true} + err = DecodeProtobufStruct(ctx, ri, payload, msgFail) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "unmarshal failed", err) + }) + t.Run("proto.Message", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + req := &protobuf.MockReq{ + Msg: "decode test", + StrMap: map[string]string{"k": "v"}, + StrList: []string{"a", "b"}, + } + + payload, err := proto.Marshal(req) + test.Assert(t, err == nil, err) + + newReq := &protobuf.MockReq{} + err = DecodeProtobufStruct(ctx, ri, payload, newReq) + test.Assert(t, err == nil, err) + test.Assert(t, newReq.Msg == req.Msg) + test.DeepEqual(t, newReq.StrMap, req.StrMap) + test.DeepEqual(t, newReq.StrList, req.StrList) + }) + t.Run("generic", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + payload := []byte("generic payload") + msg := &mockMessageReader{} + + err := DecodeProtobufStruct(ctx, ri, payload, msg) + test.Assert(t, err == nil, err) + test.DeepEqual(t, msg.data, payload) + test.Assert(t, msg.method == "TestMethod") + + // Test with empty method name + riEmpty := newMockRPCInfo("") + err = DecodeProtobufStruct(ctx, riEmpty, payload, msg) + test.Assert(t, err == errDecodePbGenericEmptyMethod, err) + + // Test read failure + msgFail := &mockMessageReader{fail: true} + err = DecodeProtobufStruct(ctx, ri, payload, msgFail) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "read failed", err) + }) + t.Run("invalid payload", func(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + payload := []byte("test") + + // Test with invalid payload type + err := DecodeProtobufStruct(ctx, ri, payload, "invalid string") + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "invalid payload string in DecodeProtobufStruct", err) + + err = DecodeProtobufStruct(ctx, ri, payload, 123) + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "invalid payload int in DecodeProtobufStruct", err) + }) +} + +func Test_EncodeDecodeRoundTrip(t *testing.T) { + ctx := context.Background() + ri := newMockRPCInfo("RoundTripMethod") + + testCases := []struct { + name string + msg *protobuf.MockReq + }{ + { + name: "simple message", + msg: &protobuf.MockReq{ + Msg: "simple", + }, + }, + { + name: "message with map", + msg: &protobuf.MockReq{ + Msg: "with map", + StrMap: map[string]string{"a": "1", "b": "2", "c": "3"}, + }, + }, + { + name: "message with list", + msg: &protobuf.MockReq{ + Msg: "with list", + StrList: []string{"x", "y", "z"}, + }, + }, + { + name: "complete message", + msg: &protobuf.MockReq{ + Msg: "complete", + StrMap: map[string]string{"key": "value"}, + StrList: []string{"item"}, + }, + }, + { + name: "empty message", + msg: &protobuf.MockReq{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test GRPC encode/decode + res, err := GRPCEncodeProtobufStruct(ctx, ri, tc.msg, false) + test.Assert(t, err == nil, err) + + decoded := &protobuf.MockReq{} + err = DecodeProtobufStruct(ctx, ri, res.Payload, decoded) + test.Assert(t, err == nil, err) + test.DeepEqual(t, decoded.Msg, tc.msg.Msg) + test.DeepEqual(t, decoded.StrMap, tc.msg.StrMap) + test.DeepEqual(t, decoded.StrList, tc.msg.StrList) + + // Test TTStream encode/decode + res2, err := TTStreamEncodeProtobufStruct(ctx, ri, tc.msg) + test.Assert(t, err == nil, err) + + decoded2 := &protobuf.MockReq{} + err = DecodeProtobufStruct(ctx, ri, res2.Payload, decoded2) + test.Assert(t, err == nil, err) + test.DeepEqual(t, decoded2.Msg, tc.msg.Msg) + test.DeepEqual(t, decoded2.StrMap, tc.msg.StrMap) + test.DeepEqual(t, decoded2.StrList, tc.msg.StrList) + }) + } +} diff --git a/pkg/remote/codec/grpc/grpc.go b/pkg/remote/codec/grpc/grpc.go index 4052efa024..104f984b01 100644 --- a/pkg/remote/codec/grpc/grpc.go +++ b/pkg/remote/codec/grpc/grpc.go @@ -23,16 +23,13 @@ import ( "fmt" "io" - "github.com/cloudwego/fastpb" - "github.com/cloudwego/gopkg/bufiox" - "google.golang.org/protobuf/proto" + icodec "github.com/cloudwego/kitex/internal/codec" "github.com/cloudwego/kitex/internal/generic" "github.com/cloudwego/kitex/internal/utils/safemcache" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/remote/codec/perrors" - "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" "github.com/cloudwego/kitex/pkg/remote/codec/thrift" "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -126,53 +123,24 @@ func (c *grpcCodec) Encode(ctx context.Context, message remote.Message, out remo payload, err = thrift.MarshalThriftData(ctx, c.ThriftCodec, message.Data()) } case serviceinfo.Protobuf: - switch t := message.Data().(type) { - // Deprecated: fastpb is no longer used - case fastpb.Writer: - size := t.Size() - if !isCompressed { - payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) - t.FastWrite(payload[dataFrameHeaderLen:]) - binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) - return writer.WriteData(payload) - } - payload = safemcache.Malloc(size) - t.FastWrite(payload) - case marshaler: - size := t.Size() - if !isCompressed { - payload = mallocWithFirstByteZeroed(size + dataFrameHeaderLen) - if _, err = t.MarshalTo(payload[dataFrameHeaderLen:]); err != nil { - return err - } - binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(size)) - return writer.WriteData(payload) - } - payload = safemcache.Malloc(size) - if _, err = t.MarshalTo(payload); err != nil { + var res icodec.EncodeResult + if !isCompressed { + res, err = icodec.GRPCEncodeProtobufStruct(ctx, message.RPCInfo(), message.Data(), false) + if err != nil { return err } - case protobufV2MsgCodec: - payload, err = t.XXX_Marshal(nil, true) - case proto.Message: - payload, err = proto.Marshal(t) - case protobuf.ProtobufMsgCodec: - payload, err = t.Marshal(nil) - case protobuf.MessageWriterWithContext: - methodName := message.RPCInfo().Invocation().MethodName() - if methodName == "" { - return errors.New("empty methodName in grpc Encode") + payload = res.Payload + if res.PreAllocate { + payload[0] = 0 + binary.BigEndian.PutUint32(payload[1:dataFrameHeaderLen], uint32(res.PreAllocateSize)) + return writer.WriteData(payload) } - actualMsg, err := t.WritePb(ctx, methodName) + } else { + res, err = icodec.GRPCEncodeProtobufStruct(ctx, message.RPCInfo(), message.Data(), true) if err != nil { return err } - payload, ok = actualMsg.([]byte) - if !ok { - return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("grpc marshal message failed: %s", err.Error())) - } - default: - return ErrInvalidPayload + payload = res.Payload } default: return ErrInvalidPayload @@ -240,34 +208,7 @@ func (c *grpcCodec) Decode(ctx context.Context, message remote.Message, in remot return thrift.UnmarshalThriftData(ctx, c.ThriftCodec, "", d, message.Data()) } case serviceinfo.Protobuf: - // Deprecated: fastpb is no longer used - if t, ok := data.(fastpb.Reader); ok { - if len(d) == 0 { - // if all fields of a struct is default value, data will be nil - // In the implementation of fastpb, if data is nil, then fastpb will skip creating this struct, as a result user will get a nil pointer which is not expected. - // So, when data is nil, use default protobuf unmarshal method to decode the struct. - // todo: fix fastpb - } else { - _, err = fastpb.ReadMessage(d, fastpb.SkipTypeCheck, t) - return err - } - } - switch t := data.(type) { - case protobufV2MsgCodec: - return t.XXX_Unmarshal(d) - case proto.Message: - return proto.Unmarshal(d, t) - case protobuf.ProtobufMsgCodec: - return t.Unmarshal(d) - case protobuf.MessageReaderWithMethodWithContext: - methodName := message.RPCInfo().Invocation().MethodName() - if methodName == "" { - return errors.New("empty methodName in grpc Decode") - } - return t.ReadPb(ctx, methodName, d) - default: - return ErrInvalidPayload - } + return icodec.DecodeProtobufStruct(ctx, message.RPCInfo(), d, data) default: return ErrInvalidPayload } diff --git a/pkg/remote/trans/ttstream/client_handler.go b/pkg/remote/trans/ttstream/client_handler.go index ec73d3ccab..6da6fb1674 100644 --- a/pkg/remote/trans/ttstream/client_handler.go +++ b/pkg/remote/trans/ttstream/client_handler.go @@ -18,6 +18,7 @@ package ttstream import ( "context" + "fmt" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" @@ -25,6 +26,7 @@ import ( "github.com/cloudwego/kitex/pkg/kerrors" "github.com/cloudwego/kitex/pkg/remote" "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -55,10 +57,13 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( if addr == nil { return nil, kerrors.ErrNoDestAddress } + protocolID, err := getProtocolID(ri) + if err != nil { + return nil, err + } var strHeader streaming.Header var intHeader IntHeader - var err error if c.headerHandler != nil { intHeader, strHeader, err = c.headerHandler.OnWriteStream(ctx) if err != nil { @@ -80,7 +85,7 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( } // create new stream - cs := newClientStream(ctx, trans, streamFrame{sid: genStreamID(), method: method}) + cs := newClientStream(ctx, trans, streamFrame{sid: genStreamID(), method: method, protocolID: protocolID}) // stream should be configured before WriteStream or there would be a race condition for metaFrameHandler cs.setRecvTimeout(rconfig.StreamRecvTimeout()) cs.setMetaFrameHandler(c.metaHandler) @@ -91,3 +96,14 @@ func (c clientTransHandler) NewStream(ctx context.Context, ri rpcinfo.RPCInfo) ( return cs, err } + +func getProtocolID(ri rpcinfo.RPCInfo) (ttheader.ProtocolID, error) { + switch ri.Config().PayloadCodec() { + case serviceinfo.Thrift: + return ttheader.ProtocolIDThriftStruct, nil + case serviceinfo.Protobuf: + return ttheader.ProtocolIDProtobufStruct, nil + default: + return 0, fmt.Errorf("not supported payload type: %v", ri.Config().PayloadCodec()) + } +} diff --git a/pkg/remote/trans/ttstream/client_handler_test.go b/pkg/remote/trans/ttstream/client_handler_test.go new file mode 100644 index 0000000000..e5d93078ac --- /dev/null +++ b/pkg/remote/trans/ttstream/client_handler_test.go @@ -0,0 +1,72 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "strings" + "testing" + + "github.com/cloudwego/gopkg/protocol/ttheader" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +func Test_getProtocolId(t *testing.T) { + tests := []struct { + desc string + payloadType serviceinfo.PayloadCodec + expectID ttheader.ProtocolID + expectErr bool + }{ + { + desc: "thrift payload", + payloadType: serviceinfo.Thrift, + expectID: ttheader.ProtocolIDThriftStruct, + }, + { + desc: "protobuf payload", + payloadType: serviceinfo.Protobuf, + expectID: ttheader.ProtocolIDProtobufStruct, + }, + { + desc: "unsupported payload type", + payloadType: serviceinfo.PayloadCodec(999), + expectID: 0, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + // Create mock rpcinfo with specific payload codec + cfg := rpcinfo.NewRPCConfig() + rpcinfo.AsMutableRPCConfig(cfg).SetPayloadCodec(tt.payloadType) + ri := rpcinfo.NewRPCInfo(nil, nil, nil, cfg, nil) + + gotID, err := getProtocolID(ri) + if tt.expectErr { + test.Assert(t, err != nil) + test.Assert(t, strings.Contains(err.Error(), "not supported payload type:"), err) + } else { + test.Assert(t, err == nil, err) + test.Assert(t, gotID == tt.expectID, gotID) + } + }) + } +} diff --git a/pkg/remote/trans/ttstream/codec.go b/pkg/remote/trans/ttstream/codec.go new file mode 100644 index 0000000000..95d89d9adb --- /dev/null +++ b/pkg/remote/trans/ttstream/codec.go @@ -0,0 +1,123 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "errors" + "fmt" + + "github.com/cloudwego/gopkg/bufiox" + gopkgthrift "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/gopkg/protocol/ttheader" + + icodec "github.com/cloudwego/kitex/internal/codec" + igeneric "github.com/cloudwego/kitex/internal/generic" + "github.com/cloudwego/kitex/pkg/generic" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +var ( + defaultThriftCodec = thriftCodec{} + defaultProtobufCodec = protobufCodec{} +) + +func getCodec(protocolId ttheader.ProtocolID) codec { + if protocolId == ttheader.ProtocolIDProtobufStruct { + return defaultProtobufCodec + } + return defaultThriftCodec +} + +// codec is used to encode/decode payload of ttstream Data Frame. +// now supports thrift and protobuf +type codec interface { + encode(ctx context.Context, ri rpcinfo.RPCInfo, msg any) (payload []byte, needRecycle bool, err error) + decode(ctx context.Context, ri rpcinfo.RPCInfo, payload []byte, msg any) error +} + +var ( + errEncodeThriftGenericEmptyMethod = errors.New("empty methodName in ttstream thrift generic Encode") + errDecodeThriftGenericEmptyMethod = errors.New("empty methodName in ttstream thrift generic Decode") +) + +type thriftCodec struct{} + +func (c thriftCodec) encode(ctx context.Context, ri rpcinfo.RPCInfo, msg any) (payload []byte, needRecycle bool, err error) { + switch t := msg.(type) { + case gopkgthrift.FastCodec: + return gopkgthrift.FastMarshal(t), false, nil + case *generic.Args: + return encodeThriftGeneric(ctx, ri, t) + case *generic.Result: + return encodeThriftGeneric(ctx, ri, t) + default: + return nil, false, fmt.Errorf("invalid payload type %T in ttstream thrift Encode", t) + } +} + +func encodeThriftGeneric(ctx context.Context, ri rpcinfo.RPCInfo, writer igeneric.ThriftWriter) (payload []byte, needRecycle bool, err error) { + methodName := ri.Invocation().MethodName() + if methodName == "" { + return nil, false, errEncodeThriftGenericEmptyMethod + } + var buf []byte + w := bufiox.NewBytesWriter(&buf) + err = writer.Write(ctx, methodName, w) + if err != nil { + return nil, false, err + } + w.Flush() + return buf, false, nil +} + +func (c thriftCodec) decode(ctx context.Context, ri rpcinfo.RPCInfo, payload []byte, msg any) error { + switch t := msg.(type) { + case gopkgthrift.FastCodec: + return gopkgthrift.FastUnmarshal(payload, t) + case *generic.Args: + return decodeThriftGeneric(ctx, ri, payload, t) + case *generic.Result: + return decodeThriftGeneric(ctx, ri, payload, t) + default: + return fmt.Errorf("invalid payload type %T in ttstream thrift Decode", t) + } +} + +func decodeThriftGeneric(ctx context.Context, ri rpcinfo.RPCInfo, payload []byte, reader igeneric.ThriftReader) error { + methodName := ri.Invocation().MethodName() + if methodName == "" { + return errDecodeThriftGenericEmptyMethod + } + r := bufiox.NewBytesReader(payload) + return reader.Read(ctx, methodName, len(payload), r) +} + +type protobufCodec struct{} + +func (p protobufCodec) encode(ctx context.Context, ri rpcinfo.RPCInfo, msg any) ([]byte, bool, error) { + res, err := icodec.TTStreamEncodeProtobufStruct(ctx, ri, msg) + if err != nil { + return nil, false, err + } + + return res.Payload, res.PreAllocate, nil +} + +func (p protobufCodec) decode(ctx context.Context, ri rpcinfo.RPCInfo, payload []byte, msg any) error { + return icodec.DecodeProtobufStruct(ctx, ri, payload, msg) +} diff --git a/pkg/remote/trans/ttstream/codec_test.go b/pkg/remote/trans/ttstream/codec_test.go new file mode 100644 index 0000000000..c92390b059 --- /dev/null +++ b/pkg/remote/trans/ttstream/codec_test.go @@ -0,0 +1,258 @@ +/* + * Copyright 2026 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ttstream + +import ( + "context" + "encoding/binary" + "math/bits" + "testing" + + "github.com/cloudwego/gopkg/protocol/ttheader" + + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote/codec/protobuf" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +// Mock marshaler (gogoproto style) for testing +type mockMarshaler struct { + data string + fail bool +} + +func (m *mockMarshaler) Size() int { + return len(m.data) +} + +func (m *mockMarshaler) MarshalTo(data []byte) (int, error) { + if m.fail { + return 0, errInvalidMessage + } + copy(data, m.data) + return len(m.data), nil +} + +type mockFastCodecMsg struct { + num int32 + v string +} + +// sizeVarint returns the encoded size of a varint. +// The size is guaranteed to be within 1 and 10, inclusive. +func sizeVarint(v uint64) int { + // This computes 1 + (bits.Len64(v)-1)/7. + // 9/64 is a good enough approximation of 1/7 + return int(9*uint32(bits.Len64(v))+64) / 64 +} + +func (p *mockFastCodecMsg) Size() (n int) { + n += sizeVarint(uint64(p.num)<<3 | uint64(2)) + n += sizeVarint(uint64(len(p.v))) + n += len(p.v) + return n +} + +func (p *mockFastCodecMsg) FastWrite(in []byte) (n int) { + n += binary.PutUvarint(in, uint64(p.num)<<3|uint64(2)) // Tag + n += binary.PutUvarint(in[n:], uint64(len(p.v))) // varint len of string + n += copy(in[n:], p.v) + + return n +} + +func (p *mockFastCodecMsg) FastRead(buf []byte, t int8, number int32) (int, error) { + if t != 2 { + panic(t) + } + p.num = number + sz, n := binary.Uvarint(buf) + buf = buf[n:] + p.v = string(buf[:sz]) + return int(sz) + n, nil +} + +func (p *mockFastCodecMsg) Marshal(_ []byte) ([]byte, error) { panic("not in use") } + +func (p *mockFastCodecMsg) Unmarshal(_ []byte) error { panic("not in use") } + +func newMockRPCInfo(method string) rpcinfo.RPCInfo { + ri := rpcinfo.NewRPCInfo( + rpcinfo.NewEndpointInfo("", method, nil, nil), + rpcinfo.NewEndpointInfo("", method, nil, nil), + rpcinfo.NewInvocation("", method), + rpcinfo.NewRPCConfig(), + rpcinfo.NewRPCStats(), + ) + return ri +} + +func Test_getCodec(t *testing.T) { + tests := []struct { + desc string + protocolID ttheader.ProtocolID + expectType string + }{ + { + desc: "thrift codec", + protocolID: ttheader.ProtocolIDThriftStruct, + expectType: "thriftCodec", + }, + { + desc: "protobuf codec", + protocolID: ttheader.ProtocolIDProtobufStruct, + expectType: "protobufCodec", + }, + { + desc: "unknown defaults to thrift", + protocolID: 99, + expectType: "thriftCodec", + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + c := getCodec(tt.protocolID) + test.Assert(t, c != nil) + // Check type by encoding behavior + switch tt.expectType { + case "thriftCodec": + _, ok := c.(thriftCodec) + test.Assert(t, ok, c) + case "protobufCodec": + _, ok := c.(protobufCodec) + test.Assert(t, ok, c) + } + }) + } +} + +func Test_thriftCodec_FastCodec(t *testing.T) { + c := thriftCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + req := &testRequest{ + A: 123, + B: "test", + } + + payload, needRecycle, err := c.encode(ctx, ri, req) + test.Assert(t, err == nil, err) + test.Assert(t, payload != nil) + test.Assert(t, !needRecycle) + + newReq := &testRequest{} + err = c.decode(ctx, ri, payload, newReq) + test.Assert(t, err == nil, err) + test.Assert(t, newReq.A == req.A) + test.Assert(t, newReq.B == req.B) +} + +func Test_thriftCodec_InvalidMessage(t *testing.T) { + c := thriftCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + _, _, err := c.encode(ctx, ri, "invalid") + test.Assert(t, err.Error() == "invalid payload type string in ttstream thrift Encode", err) + + err = c.decode(ctx, ri, []byte("data"), "invalid") + test.Assert(t, err.Error() == "invalid payload type string in ttstream thrift Decode", err) +} + +func Test_protobufCodec_ProtoMessage(t *testing.T) { + c := protobufCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + req := &protobuf.MockReq{ + Msg: "ProtoMessage", + StrMap: map[string]string{ + "TestKey": "TestVal", + }, + StrList: []string{ + "String", + }, + } + + payload, needRecycle, err := c.encode(ctx, ri, req) + test.Assert(t, err == nil, err) + test.Assert(t, payload != nil) + test.Assert(t, !needRecycle) + + newReq := &protobuf.MockReq{} + err = c.decode(ctx, ri, payload, newReq) + test.Assert(t, err == nil, err) + test.Assert(t, newReq.Msg == req.Msg) + test.DeepEqual(t, newReq.StrMap, req.StrMap) + test.DeepEqual(t, newReq.StrList, req.StrList) +} + +func Test_protobufCodec_FastPb(t *testing.T) { + c := protobufCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockFastCodecMsg{ + num: 100, + v: "test", + } + + payload, needRecycle, err := c.encode(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, payload != nil) + test.Assert(t, needRecycle) + + newMsg := &mockFastCodecMsg{} + err = c.decode(ctx, ri, payload, newMsg) + test.Assert(t, err == nil, err) + test.Assert(t, newMsg.num == msg.num) + test.Assert(t, newMsg.v == msg.v) +} + +func Test_protobufCodec_Marshaler(t *testing.T) { + c := protobufCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + msg := &mockMarshaler{data: "test"} + payload, needRecycle, err := c.encode(ctx, ri, msg) + test.Assert(t, err == nil, err) + test.Assert(t, payload != nil) + test.Assert(t, needRecycle) + test.Assert(t, string(payload) == "test") + + msgFail := &mockMarshaler{data: "test", fail: true} + payload, needRecycle, err = c.encode(ctx, ri, msgFail) + test.Assert(t, err != nil) + test.Assert(t, payload == nil) + test.Assert(t, !needRecycle) +} + +func Test_protobufCodec_InvalidMessage(t *testing.T) { + c := protobufCodec{} + ctx := context.Background() + ri := newMockRPCInfo("TestMethod") + + _, _, err := c.encode(ctx, ri, "invalid") + test.Assert(t, err != nil) + test.Assert(t, err.Error() == "invalid payload string in EncodeProtobufStruct", err) + + err = c.decode(ctx, ri, []byte("data"), "invalid") + test.Assert(t, err.Error() == "invalid payload string in DecodeProtobufStruct", err) +} diff --git a/pkg/remote/trans/ttstream/frame.go b/pkg/remote/trans/ttstream/frame.go index 4e3f529c81..a3c3199d2e 100644 --- a/pkg/remote/trans/ttstream/frame.go +++ b/pkg/remote/trans/ttstream/frame.go @@ -61,15 +61,16 @@ var ( // Frame define a TTHeader Streaming Frame type Frame struct { streamFrame - typ int32 - payload []byte + typ int32 + payload []byte + recyclePayload bool } func (f *Frame) String() string { return fmt.Sprintf("[sid=%d ftype=%d fmethod=%s]", f.sid, f.typ, f.method) } -func newFrame(sframe streamFrame, typ int32, payload []byte) (fr *Frame) { +func newFrame(sframe streamFrame, typ int32, payload []byte, recyclePayload bool) (fr *Frame) { v := framePool.Get() if v == nil { fr = new(Frame) @@ -79,13 +80,19 @@ func newFrame(sframe streamFrame, typ int32, payload []byte) (fr *Frame) { fr.streamFrame = sframe fr.typ = typ fr.payload = payload + fr.recyclePayload = recyclePayload return fr } func recycleFrame(frame *Frame) { frame.streamFrame = streamFrame{} frame.typ = 0 + if frame.recyclePayload { + mcache.Free(frame.payload) + frame.recyclePayload = false + } frame.payload = nil + frame.protocolID = 0 framePool.Put(frame) } @@ -95,7 +102,7 @@ func EncodeFrame(ctx context.Context, writer bufiox.Writer, fr *Frame) (err erro param := ttheader.EncodeParam{ Flags: ttheader.HeaderFlagsStreaming, SeqID: fr.sid, - ProtocolID: ttheader.ProtocolIDThriftStruct, + ProtocolID: fr.protocolID, } param.IntInfo = fr.meta @@ -168,6 +175,7 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro } fmethod := dp.IntInfo[ttheader.ToMethod] fsid := dp.SeqID + protocolID := dp.ProtocolID // frame payload var fpayload []byte @@ -183,8 +191,8 @@ func DecodeFrame(ctx context.Context, reader bufiox.Reader) (fr *Frame, err erro } fr = newFrame( - streamFrame{sid: fsid, method: fmethod, meta: fmeta, header: fheader, trailer: ftrailer}, - ftype, fpayload, + streamFrame{sid: fsid, method: fmethod, meta: fmeta, header: fheader, trailer: ftrailer, protocolID: protocolID}, + ftype, fpayload, false, ) return fr, nil } diff --git a/pkg/remote/trans/ttstream/frame_test.go b/pkg/remote/trans/ttstream/frame_test.go index 10cc68d6ad..8f3beedc98 100644 --- a/pkg/remote/trans/ttstream/frame_test.go +++ b/pkg/remote/trans/ttstream/frame_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/kitex/internal/test" ) @@ -31,10 +32,11 @@ func TestFrameCodec(t *testing.T) { writer := bufiox.NewDefaultWriter(&buf) reader := bufiox.NewDefaultReader(&buf) wframe := newFrame(streamFrame{ - sid: 0, - method: "method", - header: map[string]string{"key": "value"}, - }, headerFrameType, []byte("hello world")) + sid: 0, + method: "method", + header: map[string]string{"key": "value"}, + protocolID: ttheader.ProtocolIDThriftStruct, + }, headerFrameType, []byte("hello world"), false) for i := 0; i < 10; i++ { wframe.sid = int32(i) diff --git a/pkg/remote/trans/ttstream/server_handler.go b/pkg/remote/trans/ttstream/server_handler.go index 5611efe806..4fe92dbd64 100644 --- a/pkg/remote/trans/ttstream/server_handler.go +++ b/pkg/remote/trans/ttstream/server_handler.go @@ -178,8 +178,12 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, st *serve stCtx := rpcinfo.NewCtxWithRPCInfo(st.ctx, ri) ink := ri.Invocation().(rpcinfo.InvocationSetter) - // TODO: support protobuf codec, and make `strict` true when combine service is not supported. - sinfo := t.opt.SvcSearcher.SearchService(st.Service(), st.Method(), false, serviceinfo.Thrift) + cfg := rpcinfo.AsMutableRPCConfig(ri.Config()) + cfg.SetTransportProtocol(st.TransportProtocol()) + payloadCodec := getPayloadCodecFromProtocolID(st.protocolID) + cfg.SetPayloadCodec(payloadCodec) + // TODO: make `strict` true when combine service is not supported. + sinfo := t.opt.SvcSearcher.SearchService(st.Service(), st.Method(), false, payloadCodec) if sinfo == nil { err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", st.Service())) return @@ -202,7 +206,6 @@ func (t *svrTransHandler) OnStream(ctx context.Context, conn net.Conn, st *serve //nolint:staticcheck // SA1029: consts.CtxKeyMethod has been used and we just follow it stCtx = context.WithValue(stCtx, consts.CtxKeyMethod, st.Method()) } - rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(st.TransportProtocol()) // headerHandler return a new stream level ctx // it contains rpcinfo modified by HeaderHandler @@ -333,3 +336,10 @@ func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, t.opt.TracerCtl.DoFinish(ctx, ri, err) rpcStats.Reset() } + +func getPayloadCodecFromProtocolID(protocolID ttheader.ProtocolID) serviceinfo.PayloadCodec { + if protocolID == ttheader.ProtocolIDProtobufStruct { + return serviceinfo.Protobuf + } + return serviceinfo.Thrift +} diff --git a/pkg/remote/trans/ttstream/server_handler_test.go b/pkg/remote/trans/ttstream/server_handler_test.go index 893048292e..608abc28b8 100644 --- a/pkg/remote/trans/ttstream/server_handler_test.go +++ b/pkg/remote/trans/ttstream/server_handler_test.go @@ -164,252 +164,265 @@ func TestOnRead(t *testing.T) { return } - t.Run("invoking handler successfully", func(t *testing.T) { - ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() - tracer.finishFunc = func(ctx context.Context) { - ri := rpcinfo.GetRPCInfo(ctx) - test.Assert(t, ri != nil, ri) - rip, ok := ri.From().Tag(ripTag) - test.Assert(t, ok) - test.Assert(t, rip == "127.0.0.1:8888", rip) - } - var invoked int32 - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - atomic.StoreInt32(&invoked, 1) - return nil - }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderIDLServiceName: mocks.MockServiceName, - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + for _, protocolID := range []ttheader.ProtocolID{ttheader.ProtocolIDThriftStruct, ttheader.ProtocolIDProtobufStruct} { + + t.Run("invoking handler successfully", func(t *testing.T) { + ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() + tracer.finishFunc = func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri != nil, ri) + rip, ok := ri.From().Tag(ripTag) + test.Assert(t, ok) + test.Assert(t, rip == "127.0.0.1:8888", rip) + test.Assert(t, ri.Config().PayloadCodec() == getPayloadCodecFromProtocolID(protocolID), ri.Config()) + } + var invoked int32 + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + atomic.StoreInt32(&invoked, 1) + return nil + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderIDLServiceName: mocks.MockServiceName, + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, }, - }, - typ: headerFrameType, - }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - test.Assert(t, atomic.LoadInt32(&invoked) == 1) - }) - - t.Run("invoking handler panic", func(t *testing.T) { - ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() - tracer.finishFunc = func(ctx context.Context) { - ri := rpcinfo.GetRPCInfo(ctx) - test.Assert(t, ri != nil, ri) - rip, ok := ri.From().Tag(ripTag) - test.Assert(t, ok) - test.Assert(t, rip == "127.0.0.1:8888", rip) - ok, pErr := ri.Stats().Panicked() - test.Assert(t, ok) - test.Assert(t, errors.Is(pErr.(error), kerrors.ErrPanic), pErr) - test.Assert(t, errors.Is(ri.Stats().Error(), kerrors.ErrPanic)) - } - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - defer func() { - if handlerErr := recover(); handlerErr != nil { - ri := rpcinfo.GetRPCInfo(ctx) - err = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[panic] %s", handlerErr), string(debug.Stack())) - rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) - rpcStats.SetPanicked(err) - } + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() }() - panic("test") + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) + test.Assert(t, atomic.LoadInt32(&invoked) == 1) }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", - }, - }, - typ: headerFrameType, - }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - }) - - t.Run("invoking handler throws biz error", func(t *testing.T) { - ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() - tracer.finishFunc = func(ctx context.Context) { - ri := rpcinfo.GetRPCInfo(ctx) - test.Assert(t, ri != nil, ri) - rip, ok := ri.From().Tag(ripTag) - test.Assert(t, ok) - test.Assert(t, rip == "127.0.0.1:8888", rip) - bizErr := ri.Invocation().BizStatusErr() - test.Assert(t, bizErr.BizStatusCode() == 10000, bizErr) - test.Assert(t, strings.Contains(bizErr.BizMessage(), "biz-error test"), bizErr) - } - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - ri := rpcinfo.GetRPCInfo(ctx) - defer func() { - if bizErr, ok := kerrors.FromBizStatusError(err); ok { - if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { - setter.SetBizStatusErr(bizErr) - err = nil + + t.Run("invoking handler panic", func(t *testing.T) { + ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() + tracer.finishFunc = func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri != nil, ri) + rip, ok := ri.From().Tag(ripTag) + test.Assert(t, ok) + test.Assert(t, rip == "127.0.0.1:8888", rip) + ok, pErr := ri.Stats().Panicked() + test.Assert(t, ok) + test.Assert(t, errors.Is(pErr.(error), kerrors.ErrPanic), pErr) + test.Assert(t, errors.Is(ri.Stats().Error(), kerrors.ErrPanic)) + test.Assert(t, ri.Config().PayloadCodec() == getPayloadCodecFromProtocolID(protocolID), ri.Config()) + } + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + defer func() { + if handlerErr := recover(); handlerErr != nil { + ri := rpcinfo.GetRPCInfo(ctx) + err = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[panic] %s", handlerErr), string(debug.Stack())) + rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) + rpcStats.SetPanicked(err) } - } + }() + panic("test") + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, + }, + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() }() - return kerrors.NewBizStatusError(10000, "biz-error test") + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + + t.Run("invoking handler throws biz error", func(t *testing.T) { + ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() + tracer.finishFunc = func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri != nil, ri) + rip, ok := ri.From().Tag(ripTag) + test.Assert(t, ok) + test.Assert(t, rip == "127.0.0.1:8888", rip) + bizErr := ri.Invocation().BizStatusErr() + test.Assert(t, bizErr.BizStatusCode() == 10000, bizErr) + test.Assert(t, strings.Contains(bizErr.BizMessage(), "biz-error test"), bizErr) + test.Assert(t, ri.Config().PayloadCodec() == getPayloadCodecFromProtocolID(protocolID), ri.Config()) + } + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + ri := rpcinfo.GetRPCInfo(ctx) + defer func() { + if bizErr, ok := kerrors.FromBizStatusError(err); ok { + if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok { + setter.SetBizStatusErr(bizErr) + err = nil + } + } + }() + return kerrors.NewBizStatusError(10000, "biz-error test") + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, }, - }, - typ: headerFrameType, - }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - }) - - t.Run("invoking handler and getting K_METHOD successfully", func(t *testing.T) { - ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() - tracer.finishFunc = func(ctx context.Context) { - ri := rpcinfo.GetRPCInfo(ctx) - test.Assert(t, ri != nil, ri) - rip, ok := ri.From().Tag(ripTag) - test.Assert(t, ok) - test.Assert(t, rip == "127.0.0.1:8888", rip) - // retrieve TO method from rpcinfo - mt := ri.To().Method() - test.Assert(t, mt == mocks.MockStreamingMethod, mt) - // retrieve K_METHOD from ctx - kMt := ctx.Value(consts.CtxKeyMethod).(string) - test.Assert(t, kMt == mocks.MockStreamingMethod, kMt) - } - var invoked int32 - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - atomic.StoreInt32(&invoked, 1) - // retrieve K_METHOD from ctx - kMt := ctx.Value(consts.CtxKeyMethod).(string) - test.Assert(t, kMt == mocks.MockStreamingMethod, kMt) - return nil + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() + }() + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderIDLServiceName: mocks.MockServiceName, - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + + t.Run("invoking handler and getting K_METHOD successfully", func(t *testing.T) { + ctx, ripTag, tracer, transHdl, wconn, wbuf, mockConn := prepare() + tracer.finishFunc = func(ctx context.Context) { + ri := rpcinfo.GetRPCInfo(ctx) + test.Assert(t, ri != nil, ri) + rip, ok := ri.From().Tag(ripTag) + test.Assert(t, ok) + test.Assert(t, rip == "127.0.0.1:8888", rip) + // retrieve TO method from rpcinfo + mt := ri.To().Method() + test.Assert(t, mt == mocks.MockStreamingMethod, mt) + // retrieve K_METHOD from ctx + kMt := ctx.Value(consts.CtxKeyMethod).(string) + test.Assert(t, kMt == mocks.MockStreamingMethod, kMt) + test.Assert(t, ri.Config().PayloadCodec() == getPayloadCodecFromProtocolID(protocolID), ri.Config()) + } + var invoked int32 + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + atomic.StoreInt32(&invoked, 1) + // retrieve K_METHOD from ctx + kMt := ctx.Value(consts.CtxKeyMethod).(string) + test.Assert(t, kMt == mocks.MockStreamingMethod, kMt) + return nil + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderIDLServiceName: mocks.MockServiceName, + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, }, - }, - typ: headerFrameType, + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() + }() + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) + test.Assert(t, atomic.LoadInt32(&invoked) == 1) }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - test.Assert(t, atomic.LoadInt32(&invoked) == 1) - }) - - t.Run("rpcinfo reuse disabled", func(t *testing.T) { - var ri rpcinfo.RPCInfo - ctx, _, _, transHdl, wconn, wbuf, mockConn := prepare() - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - ri = rpcinfo.GetRPCInfo(ctx) - return nil - }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderIDLServiceName: mocks.MockServiceName, - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + t.Run("rpcinfo reuse disabled", func(t *testing.T) { + var ri rpcinfo.RPCInfo + ctx, _, _, transHdl, wconn, wbuf, mockConn := prepare() + + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + ri = rpcinfo.GetRPCInfo(ctx) + return nil + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderIDLServiceName: mocks.MockServiceName, + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, }, - }, - typ: headerFrameType, - }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - test.Assert(t, ri.Invocation().MethodName() == mocks.MockStreamingMethod, ri) - }) - - t.Run("access rpcinfo asynchronously would not cause panic", func(t *testing.T) { - ctx, _, _, transHdl, wconn, wbuf, mockConn := prepare() - var wg sync.WaitGroup - - transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { - for i := 0; i < 20; i++ { - wg.Add(1) - go func(c context.Context) { - defer func() { - if r := recover(); r != nil { - t.Error(r) - } - wg.Done() - }() - time.Sleep(50 * time.Millisecond) - ri := rpcinfo.GetRPCInfo(c) - // access rpcinfo - ri.From().Tag("key") - }(ctx) - } - return nil + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() + }() + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) + test.Assert(t, ri.Invocation().MethodName() == mocks.MockStreamingMethod, ri) }) - err := EncodeFrame(context.Background(), wbuf, &Frame{ - streamFrame: streamFrame{ - sid: 1, - method: mocks.MockStreamingMethod, - header: map[string]string{ - ttheader.HeaderIDLServiceName: mocks.MockServiceName, - ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + + t.Run("access rpcinfo asynchronously would not cause panic", func(t *testing.T) { + ctx, _, _, transHdl, wconn, wbuf, mockConn := prepare() + var wg sync.WaitGroup + + transHdl.SetInvokeHandleFunc(func(ctx context.Context, req, resp interface{}) (err error) { + for i := 0; i < 20; i++ { + wg.Add(1) + go func(c context.Context) { + defer func() { + if r := recover(); r != nil { + t.Error(r) + } + wg.Done() + }() + time.Sleep(50 * time.Millisecond) + ri := rpcinfo.GetRPCInfo(c) + // access rpcinfo + ri.From().Tag("key") + }(ctx) + } + return nil + }) + err := EncodeFrame(context.Background(), wbuf, &Frame{ + streamFrame: streamFrame{ + sid: 1, + method: mocks.MockStreamingMethod, + header: map[string]string{ + ttheader.HeaderIDLServiceName: mocks.MockServiceName, + ttheader.HeaderTransRemoteAddr: "127.0.0.1:8888", + }, + protocolID: protocolID, }, - }, - typ: headerFrameType, + typ: headerFrameType, + }) + test.Assert(t, err == nil, err) + err = wbuf.Flush() + test.Assert(t, err == nil, err) + go func() { + time.Sleep(1 * time.Second) + wconn.Close() + }() + err = transHdl.OnRead(ctx, mockConn) + test.Assert(t, err == nil, err) + wg.Wait() }) - test.Assert(t, err == nil, err) - err = wbuf.Flush() - test.Assert(t, err == nil, err) - go func() { - time.Sleep(1 * time.Second) - wconn.Close() - }() - err = transHdl.OnRead(ctx, mockConn) - test.Assert(t, err == nil, err) - wg.Wait() - }) + } } diff --git a/pkg/remote/trans/ttstream/stream.go b/pkg/remote/trans/ttstream/stream.go index a6424b25b3..552a328742 100644 --- a/pkg/remote/trans/ttstream/stream.go +++ b/pkg/remote/trans/ttstream/stream.go @@ -48,16 +48,18 @@ func newBasicStream(ctx context.Context, writer streamWriter, smeta streamFrame) s.writer = writer s.wheader = make(streaming.Header) s.wtrailer = make(streaming.Trailer) + s.codec = getCodec(smeta.protocolID) return s } // streamFrame define a basic stream frame type streamFrame struct { - sid int32 - method string - meta IntHeader - header streaming.Header // key:value, key is full name - trailer streaming.Trailer + sid int32 + method string + meta IntHeader + header streaming.Header // key:value, key is full name + trailer streaming.Trailer + protocolID ttheader.ProtocolID } const ( @@ -86,6 +88,7 @@ type stream struct { recvTimeout time.Duration closeCallback []func(error) + codec codec } func (s *stream) Service() string { @@ -109,7 +112,7 @@ func (s *stream) TransportProtocol() ktransport.Protocol { // here, and the context passed in by the user is ignored. func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { // encode payload - payload, err := EncodePayload(s.ctx, msg) + payload, needRecycle, err := s.codec.encode(s.ctx, s.rpcInfo, msg) if err != nil { return err } @@ -121,7 +124,7 @@ func (s *stream) SendMsg(ctx context.Context, msg any) (err error) { } } // send data frame - return s.writeFrame(dataFrameType, nil, nil, payload) + return s.writeFrame(dataFrameType, nil, nil, payload, needRecycle) } func (s *stream) RecvMsg(ctx context.Context, data any) error { @@ -135,7 +138,7 @@ func (s *stream) RecvMsg(ctx context.Context, data any) error { if err != nil { return err } - err = DecodePayload(nctx, payload, data) + err = s.codec.decode(nctx, s.rpcInfo, payload, data) // payload will not be access after decode mcache.Free(payload) @@ -169,8 +172,8 @@ func (s *stream) runCloseCallback(exception error) { _ = s.writer.CloseStream(s.sid) } -func (s *stream) writeFrame(ftype int32, header streaming.Header, trailer streaming.Trailer, payload []byte) (err error) { - fr := newFrame(streamFrame{sid: s.sid, method: s.method, header: header, trailer: trailer}, ftype, payload) +func (s *stream) writeFrame(ftype int32, header streaming.Header, trailer streaming.Trailer, payload []byte, recyclePayload bool) (err error) { + fr := newFrame(streamFrame{sid: s.sid, method: s.method, header: header, trailer: trailer, protocolID: s.protocolID}, ftype, payload, recyclePayload) return s.writer.WriteFrame(fr) } @@ -190,7 +193,7 @@ func (s *stream) sendTrailer(exception error) (err error) { return err } } - err = s.writeFrame(trailerFrameType, nil, wtrailer, payload) + err = s.writeFrame(trailerFrameType, nil, wtrailer, payload, false) return err } @@ -207,7 +210,7 @@ func (s *stream) sendRst(exception error, cancelPath string) (err error) { header = make(streaming.Header) header[ttheader.HeaderTTStreamCancelPath] = cancelPath } - return s.writeFrame(rstFrameType, header, nil, payload) + return s.writeFrame(rstFrameType, header, nil, payload, false) } // === Frame OnRead callback diff --git a/pkg/remote/trans/ttstream/stream_client_test.go b/pkg/remote/trans/ttstream/stream_client_test.go index acc61b13d6..9d578e968a 100644 --- a/pkg/remote/trans/ttstream/stream_client_test.go +++ b/pkg/remote/trans/ttstream/stream_client_test.go @@ -26,12 +26,14 @@ import ( "testing" "time" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/rpcinfo" ) func newTestClientStream(ctx context.Context) *clientStream { - return newClientStream(ctx, mockStreamWriter{}, streamFrame{}) + return newClientStream(ctx, mockStreamWriter{}, streamFrame{protocolID: ttheader.ProtocolIDThriftStruct}) } func Test_clientStreamStateChange(t *testing.T) { @@ -136,7 +138,7 @@ func Test_clientStream_parseCtxErr(t *testing.T) { for _, tc := range testcases { t.Run(tc.desc, func(t *testing.T) { ctx, cancel := tc.ctxFunc() - cs := newClientStream(ctx, nil, streamFrame{}) + cs := newClientStream(ctx, nil, streamFrame{protocolID: ttheader.ProtocolIDThriftStruct}) cancel() finalEx, cancelPath := cs.parseCtxErr(ctx) test.DeepEqual(t, finalEx, tc.expectEx) @@ -147,7 +149,7 @@ func Test_clientStream_parseCtxErr(t *testing.T) { func Test_clientStream_SendMsg(t *testing.T) { ctx := context.Background() - cs := newClientStream(ctx, &mockStreamWriter{}, streamFrame{}) + cs := newClientStream(ctx, &mockStreamWriter{}, streamFrame{protocolID: ttheader.ProtocolIDThriftStruct}) req := &testRequest{B: "SendMsgTest"} // Send successfully @@ -161,7 +163,7 @@ func Test_clientStream_SendMsg(t *testing.T) { test.Assert(t, errors.Is(err, errIllegalOperation), err) test.Assert(t, strings.Contains(err.Error(), "stream is closed send")) - cs = newClientStream(ctx, &mockStreamWriter{}, streamFrame{}) + cs = newClientStream(ctx, &mockStreamWriter{}, streamFrame{protocolID: ttheader.ProtocolIDThriftStruct}) // Send retrieves the close stream exception ex := errDownstreamCancel.newBuilder().withSide(clientSide) cs.close(ex, false, "") diff --git a/pkg/remote/trans/ttstream/stream_server.go b/pkg/remote/trans/ttstream/stream_server.go index 10db82fe7f..770dadef41 100644 --- a/pkg/remote/trans/ttstream/stream_server.go +++ b/pkg/remote/trans/ttstream/stream_server.go @@ -82,7 +82,7 @@ func (s *serverStream) sendHeader() (err error) { if wheader == nil { return fmt.Errorf("stream header already sent") } - err = s.writeFrame(headerFrameType, wheader, nil, nil) + err = s.writeFrame(headerFrameType, wheader, nil, nil, false) return err } diff --git a/pkg/remote/trans/ttstream/stream_server_test.go b/pkg/remote/trans/ttstream/stream_server_test.go index fcab6b5431..e4d4cade4b 100644 --- a/pkg/remote/trans/ttstream/stream_server_test.go +++ b/pkg/remote/trans/ttstream/stream_server_test.go @@ -25,6 +25,8 @@ import ( "testing" "time" + "github.com/cloudwego/gopkg/protocol/ttheader" + "github.com/cloudwego/kitex/internal/test" "github.com/cloudwego/kitex/pkg/streaming" ) @@ -36,7 +38,7 @@ func newTestServerStream() *serverStream { func newTestServerStreamWithStreamWriter(w streamWriter) *serverStream { ctx, cancel := context.WithCancel(context.Background()) ctx, cancelFunc := newContextWithCancelReason(ctx, cancel) - srvSt := newServerStream(ctx, w, streamFrame{}) + srvSt := newServerStream(ctx, w, streamFrame{protocolID: ttheader.ProtocolIDThriftStruct}) srvSt.cancelFunc = cancelFunc return srvSt } diff --git a/pkg/remote/trans/ttstream/test_utils.go b/pkg/remote/trans/ttstream/test_utils.go index 02757d41c3..0384ff5b38 100644 --- a/pkg/remote/trans/ttstream/test_utils.go +++ b/pkg/remote/trans/ttstream/test_utils.go @@ -21,6 +21,7 @@ package ttstream import ( "context" + "github.com/cloudwego/gopkg/protocol/ttheader" "github.com/cloudwego/netpoll" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -61,7 +62,7 @@ func newTestStreamPipe(sinfo *serviceinfo.ServiceInfo, method string) (*clientSt strHeader := make(streaming.Header) ctrans := newClientTransport(cconn, nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: method, protocolID: ttheader.ProtocolIDThriftStruct}) if err = ctrans.WriteStream(ctx, cs, intHeader, strHeader); err != nil { return nil, nil, err } diff --git a/pkg/remote/trans/ttstream/transport_client.go b/pkg/remote/trans/ttstream/transport_client.go index 10903ea4df..c86c6c38ff 100644 --- a/pkg/remote/trans/ttstream/transport_client.go +++ b/pkg/remote/trans/ttstream/transport_client.go @@ -229,6 +229,7 @@ func (t *clientTransport) loopWrite() error { for i := 0; i < n; i++ { fr := fcache[i] if err = EncodeFrame(context.Background(), writer, fr); err != nil { + recycleFrame(fr) return err } recycleFrame(fr) @@ -275,7 +276,7 @@ func (t *clientTransport) WriteStream( ) error { t.storeStream(s) // send create stream request for server - fr := newFrame(streamFrame{sid: s.sid, method: s.method, header: strHeader, meta: intHeader}, headerFrameType, nil) + fr := newFrame(streamFrame{sid: s.sid, method: s.method, header: strHeader, meta: intHeader, protocolID: s.protocolID}, headerFrameType, nil, false) if err := t.WriteFrame(fr); err != nil { return err } diff --git a/pkg/remote/trans/ttstream/transport_server.go b/pkg/remote/trans/ttstream/transport_server.go index 4a67dd33bf..146311b269 100644 --- a/pkg/remote/trans/ttstream/transport_server.go +++ b/pkg/remote/trans/ttstream/transport_server.go @@ -219,6 +219,7 @@ func (t *serverTransport) loopWrite() error { for i := 0; i < n; i++ { fr := fcache[i] if err = EncodeFrame(context.Background(), writer, fr); err != nil { + recycleFrame(fr) return err } recycleFrame(fr) diff --git a/pkg/remote/trans/ttstream/transport_test.go b/pkg/remote/trans/ttstream/transport_test.go index 9a3fae91c3..7a5781d366 100644 --- a/pkg/remote/trans/ttstream/transport_test.go +++ b/pkg/remote/trans/ttstream/transport_test.go @@ -78,7 +78,7 @@ func TestTransportBasic(t *testing.T) { ctrans := newClientTransport(cconn, nil) defer ctrans.Close(nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) test.Assert(t, err == nil, err) strans := newServerTransport(sconn) @@ -147,7 +147,7 @@ func TestTransportServerStreaming(t *testing.T) { ctrans := newClientTransport(cconn, nil) defer ctrans.Close(nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) test.Assert(t, err == nil, err) strans := newServerTransport(sconn) @@ -212,7 +212,7 @@ func TestTransportException(t *testing.T) { ctrans := newClientTransport(cconn, nil) defer ctrans.Close(nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, make(IntHeader), make(streaming.Header)) test.Assert(t, err == nil, err) strans := newServerTransport(sconn) @@ -242,7 +242,7 @@ func TestTransportException(t *testing.T) { // server send illegal frame ctx = context.Background() - cs = newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs = newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, make(IntHeader), make(streaming.Header)) test.Assert(t, err == nil, err) ss, err = strans.ReadStream(context.Background()) @@ -271,7 +271,7 @@ func TestTransportClose(t *testing.T) { ctrans := newClientTransport(cconn, nil) defer ctrans.Close(nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) test.Assert(t, err == nil, err) strans := newServerTransport(sconn) @@ -349,7 +349,7 @@ func Test_clientStreamReceiveTrailer(t *testing.T) { ctrans := newClientTransport(cconn, nil) defer ctrans.Close(nil) ctx := context.Background() - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi"}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: "Bidi", protocolID: ttheader.ProtocolIDThriftStruct}) err = ctrans.WriteStream(ctx, cs, intHeader, strHeader) test.Assert(t, err == nil, err) strans := newServerTransport(sconn) @@ -407,7 +407,7 @@ func Test_clientStreamReceiveMetaFrame(t *testing.T) { nil, remoteinfo.NewRemoteInfo(&rpcinfo.EndpointBasicInfo{Tags: make(map[string]string)}, testMethod), nil, nil, nil, )) finishCh := make(chan struct{}) - cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: testMethod}) + cs := newClientStream(ctx, ctrans, streamFrame{sid: genStreamID(), method: testMethod, protocolID: ttheader.ProtocolIDThriftStruct}) cs.setMetaFrameHandler(&mockMetaFrameHandler{ onMetaFrame: func(ctx context.Context, intHeader IntHeader, header streaming.Header, payload []byte) error { ri := rpcinfo.GetRPCInfo(ctx) @@ -571,7 +571,7 @@ func initTestStreams(t *testing.T, cCtx context.Context, method, cliNodeName, sr intHeader := make(IntHeader) strHeader := make(streaming.Header) ctrans := newClientTransport(cconn, nil) - cs := newClientStream(cCtx, ctrans, streamFrame{sid: genStreamID(), method: method}) + cs := newClientStream(cCtx, ctrans, streamFrame{sid: genStreamID(), method: method, protocolID: ttheader.ProtocolIDThriftStruct}) cs.rpcInfo = rpcinfo.NewRPCInfo( rpcinfo.NewEndpointInfo(cliNodeName, method, nil, nil), nil, nil, nil, nil) err = ctrans.WriteStream(cCtx, cs, intHeader, strHeader) diff --git a/tool/internal_pkg/tpl/client_v2.go b/tool/internal_pkg/tpl/client_v2.go index 5faaa55a6c..2c9a31d062 100644 --- a/tool/internal_pkg/tpl/client_v2.go +++ b/tool/internal_pkg/tpl/client_v2.go @@ -60,11 +60,8 @@ func NewClient(destService string, opts ...client.Option) (Client, error) { {{template "@client.go-NewClient-option" .}} {{if .HasStreaming}} - {{- if eq $.Codec "thrift"}}{{/* thrift streaming enable ttheader streaming protocol by default */}} + {{/* thrift streaming enable ttheader streaming protocol by default */}} options = append(options, client.WithTransportProtocol(transport.TTHeaderStreaming)) - {{- else}} - options = append(options, client.WithTransportProtocol(transport.GRPC)) - {{- end}} {{end}} options = append(options, opts...)