Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/protoc-gen-go-grpc/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module google.golang.org/grpc/cmd/protoc-gen-go-grpc

go 1.23.0
go 1.24.0

require (
google.golang.org/grpc v1.70.0
Expand Down
14 changes: 11 additions & 3 deletions credentials/alts/alts.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ type ClientOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
// EnableRcvlowat enables CPU savings via SO_RCVLOWAT.
EnableRcvlowat bool
}

// DefaultClientOptions creates a new ClientOptions object with the default
Expand All @@ -120,6 +122,8 @@ type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
// EnableRcvlowat enables CPU savings via SO_RCVLOWAT.
EnableRcvlowat bool
}

// DefaultServerOptions creates a new ServerOptions object with the default
Expand All @@ -138,19 +142,20 @@ type altsTC struct {
accounts []string
hsAddress string
boundAccessToken string
rcvlowat bool
}

// NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress, opts.EnableRcvlowat)
}

// NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress, opts.EnableRcvlowat)
}

func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
func newALTS(side core.Side, accounts []string, hsAddress string, rcvlowat bool) credentials.TransportCredentials {
once.Do(func() {
vmOnGCP = googlecloud.OnGCE()
})
Expand All @@ -165,6 +170,7 @@ func newALTS(side core.Side, accounts []string, hsAddress string) credentials.Tr
side: side,
accounts: accounts,
hsAddress: hsAddress,
rcvlowat: rcvlowat,
}
}

Expand Down Expand Up @@ -200,6 +206,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
MinRpcVersion: minRPCVersion,
}
opts.BoundAccessToken = g.boundAccessToken
opts.Rcvlowat = g.rcvlowat
chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, opts)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -243,6 +250,7 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
MaxRpcVersion: maxRPCVersion,
MinRpcVersion: minRPCVersion,
}
opts.Rcvlowat = g.rcvlowat
shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, opts)
if err != nil {
return nil, nil, err
Expand Down
32 changes: 0 additions & 32 deletions credentials/alts/internal/conn/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
package conn

import (
"encoding/binary"
"errors"
"fmt"
)

const (
Expand All @@ -48,33 +46,3 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
tail = head[len(in):]
return head, tail
}

// ParseFramedMsg parses the provided buffer and returns a frame of the format
// msgLength+msg together with any remaining bytes in that buffer. When b does
// not yet hold a complete frame, the returned frame is nil and b itself is
// returned as the remaining bytes. A length field larger than maxLen yields an
// error.
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
	length, headerComplete := parseMessageLength(b)
	switch {
	case !headerComplete:
		// The size field is not complete yet; hand the buffer back untouched.
		return nil, b, nil
	case length > maxLen:
		return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
	case len(b) < int(length)+MsgLenFieldSize:
		// The header is present but the frame body is not complete yet.
		return nil, b, nil
	}
	// The frame spans the length field plus the message it describes.
	frameEnd := MsgLenFieldSize + length
	return b[:frameEnd], b[frameEnd:], nil
}

// parseMessageLength decodes the little-endian message length stored in the
// frame header at the start of b. The boolean result reports whether b holds
// enough bytes for the length field; when it does not, (0, false) is returned.
func parseMessageLength(b []byte) (uint32, bool) {
	if len(b) >= MsgLenFieldSize {
		return binary.LittleEndian.Uint32(b[:MsgLenFieldSize]), true
	}
	return 0, false
}
26 changes: 26 additions & 0 deletions credentials/alts/internal/conn/conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//go:build !linux

/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package conn

// setRcvlowat is a no-op in builds that are not for Linux: it ignores length
// and always returns nil. SO_RCVLOWAT exists on non-Linux OSes, but we haven't
// tested them.
func (p *conn) setRcvlowat(length int) error {
return nil
}
79 changes: 79 additions & 0 deletions credentials/alts/internal/conn/conn_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package conn

import (
"errors"

"golang.org/x/sys/unix"
)

// setRcvlowat updates SO_RCVLOWAT on the underlying socket to reduce CPU
// usage. length is the frame length supplied by the caller; the watermark is
// derived from it, from the capacity of p.protected, and from how many bytes
// p.protected already holds.
//
// It is a no-op when p.rawConn is nil, and it skips the syscall when the
// computed watermark equals the value already recorded in p.rcvlowat. Any
// error from reaching the file descriptor or from setsockopt is returned, and
// p.rcvlowat is only updated after the kernel accepted the new value.
func (p *conn) setRcvlowat(length int) error {
	if p.rawConn == nil {
		return nil
	}

	const (
		rcvlowatMax = 16 * 1024 * 1024
		rcvlowatMin = 32 * 1024
		rcvlowatGap = 16 * 1024
	)

	remaining := min(cap(p.protected), length, rcvlowatMax)

	// Small SO_RCVLOWAT values don't actually save CPU.
	if remaining < rcvlowatMin {
		remaining = 0
	}

	// Allow for a small gap, which can wake us up a tiny bit early. This
	// helps with latency, as bytes can arrive between wakeup and the
	// ensuing read.
	if remaining > 0 {
		remaining -= rcvlowatGap
	}

	// Don't hold up the socket once we've hit our threshold.
	if len(p.protected) > remaining {
		remaining = 0
	}

	// Don't enable SO_RCVLOWAT if it's not useful.
	if p.rcvlowat <= 1 && remaining <= 1 {
		return nil
	}

	// Don't make a syscall if nothing changed.
	if p.rcvlowat == remaining {
		return nil
	}

	// Make the actual setsockopt call. The value handed to the kernel must be
	// the newly computed watermark, not the previously cached p.rcvlowat;
	// otherwise the socket keeps its old setting while p.rcvlowat claims the
	// new one.
	var sockoptErr error
	err := p.rawConn.Control(func(fd uintptr) {
		sockoptErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVLOWAT, remaining)
	})
	if err != nil || sockoptErr != nil {
		return errors.Join(err, sockoptErr)
	}

	// Record the value only after the kernel accepted it.
	p.rcvlowat = remaining
	return nil
}
87 changes: 72 additions & 15 deletions credentials/alts/internal/conn/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"math"
"net"
"syscall"

core "google.golang.org/grpc/credentials/alts/internal"
)
Expand Down Expand Up @@ -56,17 +57,12 @@ const (
MsgLenFieldSize = 4
// The byte size of the message type field of a framed message.
msgTypeFieldSize = 4
// The bytes size limit for a ALTS record message.
// The bytes size limit for an ALTS record message.
altsRecordLengthLimit = 1024 * 1024 // 1 MiB
// The default byte size of an ALTS record message.
altsRecordDefaultLength = 4 * 1024 // 4KiB
// Message type value included in ALTS record framing.
altsRecordMsgType = uint32(0x06)
// The initial write buffer size.
altsWriteBufferInitialSize = 32 * 1024 // 32KiB
// The maximum write buffer size. This *must* be multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
// The initial buffer used to read from the network.
altsReadBufferInitialSize = 32 * 1024 // 32KiB
)
Expand Down Expand Up @@ -102,11 +98,23 @@ type conn struct {
nextFrame []byte
// overhead is the calculated overhead of each frame.
overhead int
// rcvlowat is the "receive low watermark" used to avoid unnecessary
// early returns from the kernel during [conn.Read], which saves CPU and
// can boost throughput under load. When we receive the first few bytes
// of a message we examine the length field. If, for example, we know
// there's 512KB of data remaining in the record, rcvlowat tells the
// kernel "don't wake me up every time you get another packet; wait
// until you have all 512KB."
//
// See SO_RCVLOWAT in tcp(7) for more info.
rcvlowat int
// rawConn allows us to set SO_RCVLOWAT on the underlying TCP socket.
rawConn syscall.RawConn
}

// NewConn creates a new secure channel instance given the other party role and
// handshaking result.
func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte) (net.Conn, error) {
func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, protected []byte, rcvlowat bool) (net.Conn, error) {
newCrypto := protocols[recordProtocol]
if newCrypto == nil {
return nil, fmt.Errorf("negotiated unknown next_protocol %q", recordProtocol)
Expand All @@ -116,7 +124,7 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
return nil, fmt.Errorf("protocol %q: %v", recordProtocol, err)
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
payloadLengthLimit := altsRecordLengthLimit - overhead
// We pre-allocate protected to be of size 32KB during initialization.
// We increase the size of the buffer by the required amount if it can't
// hold a complete encrypted record.
Expand All @@ -134,6 +142,18 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
nextFrame: protectedBuf,
overhead: overhead,
}

if rcvlowat {
tcpConn, ok := c.(*net.TCPConn)
if !ok {
return nil, fmt.Errorf("rcvlowat requires a *net.TCPConn, but got %T", c)
}
if altsConn.rawConn, err = tcpConn.SyscallConn(); err != nil {
return nil, fmt.Errorf("failed to get raw connection: %w", err)
}
altsConn.rcvlowat = 1
}

return altsConn, nil
}

Expand All @@ -144,7 +164,8 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
func (p *conn) Read(b []byte) (n int, err error) {
if len(p.buf) == 0 {
var framedMsg []byte
framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit)
var length uint32
framedMsg, length, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit)
if err != nil {
return n, err
}
Expand All @@ -159,6 +180,10 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if err := p.setRcvlowat(int(length)); err != nil {
return 0, err
}

if len(p.protected) == cap(p.protected) {
// We can parse the length header to know exactly how large
// the buffer needs to be to hold the entire frame.
Expand All @@ -184,7 +209,7 @@ func (p *conn) Read(b []byte) (n int, err error) {
return 0, err
}
p.protected = p.protected[:len(p.protected)+n]
framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit)
framedMsg, length, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -225,6 +250,39 @@ func (p *conn) Read(b []byte) (n int, err error) {
return n, nil
}

// parseFramedMsg looks for one complete frame (msgLength+msg) at the start of
// b. It returns that frame iff the whole frame is available, plus the message
// length read from the header (0 when the header itself is incomplete). A
// length above maxLen is reported as an error.
//
// As a side effect p.nextFrame is updated: it is set to the bytes following
// the returned frame, or to b itself when no frame is returned.
func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, uint32, error) {
	// Until a whole frame is found, everything in b is still pending.
	p.nextFrame = b

	length, headerComplete := parseMessageLength(b)
	switch {
	case !headerComplete:
		// The size field is not complete yet.
		return nil, length, nil
	case length > maxLen:
		return nil, length, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
	case len(b) < int(length)+MsgLenFieldSize:
		// The header is present but the frame body is not complete yet.
		return nil, length, nil
	}

	// The frame spans the length field plus the message it describes.
	frameEnd := MsgLenFieldSize + length
	p.nextFrame = b[frameEnd:]
	return b[:frameEnd], length, nil
}

// parseMessageLength decodes the little-endian message length stored in the
// frame header at the start of b. The boolean result reports whether b holds
// enough bytes for the length field; when it does not, (0, false) is returned.
func parseMessageLength(b []byte) (uint32, bool) {
	if len(b) >= MsgLenFieldSize {
		return binary.LittleEndian.Uint32(b[:MsgLenFieldSize]), true
	}
	return 0, false
}

// Write encrypts, frames, and writes bytes from b to the underlying connection.
func (p *conn) Write(b []byte) (n int, err error) {
n = len(b)
Expand All @@ -233,10 +291,9 @@ func (p *conn) Write(b []byte) (n int, err error) {
size := len(b) + numOfFrames*p.overhead
// If writeBuf is too small, increase its size up to the maximum size.
partialBSize := len(b)
if size > altsWriteBufferMaxSize {
size = altsWriteBufferMaxSize
const numOfFramesInMaxWriteBuf = altsWriteBufferMaxSize / altsRecordDefaultLength
partialBSize = numOfFramesInMaxWriteBuf * p.payloadLengthLimit
if size > altsRecordLengthLimit {
size = altsRecordLengthLimit
partialBSize = p.payloadLengthLimit
}
if len(p.writeBuf) < size {
p.writeBuf = make([]byte, size)
Expand Down Expand Up @@ -282,7 +339,7 @@ func (p *conn) Write(b []byte) (n int, err error) {
// written. This means we need to remove header,
// encryption overheads, and any partially-written
// frame data.
numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordDefaultLength)))
numOfWrittenFrames := int(math.Floor(float64(nn) / float64(altsRecordLengthLimit)))
return partialBStart + numOfWrittenFrames*p.payloadLengthLimit, err
}
}
Expand Down
Loading
Loading