@@ -24,6 +24,8 @@ import (
24
24
"gvisor.dev/gvisor/pkg/sync"
25
25
"gvisor.dev/gvisor/pkg/tcpip"
26
26
"gvisor.dev/gvisor/pkg/tcpip/header"
27
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
28
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
27
29
"gvisor.dev/gvisor/pkg/tcpip/stack"
28
30
"gvisor.dev/gvisor/pkg/tcpip/transport"
29
31
"gvisor.dev/gvisor/pkg/waiter"
@@ -952,21 +954,26 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
952
954
e .multicastAddr = addr
953
955
954
956
case * tcpip.AddMembershipOption :
955
- if ! (header .IsV4MulticastAddress (v .MulticastAddr ) && e .netProto == header .IPv4ProtocolNumber ) && ! (header .IsV6MulticastAddress (v .MulticastAddr ) && e .netProto == header .IPv6ProtocolNumber ) {
957
+ var proto tcpip.NetworkProtocolNumber
958
+ switch {
959
+ case header .IsV4MulticastAddress (v .MulticastAddr ):
960
+ proto = ipv4 .ProtocolNumber
961
+ case header .IsV6MulticastAddress (v .MulticastAddr ):
962
+ proto = ipv6 .ProtocolNumber
963
+ default :
956
964
return & tcpip.ErrInvalidOptionValue {}
957
965
}
958
966
959
967
nicID := v .NIC
960
-
961
968
if v .InterfaceAddr .Unspecified () {
962
969
if nicID == 0 {
963
- if r , err := e .stack .FindRoute (0 , tcpip.Address {}, v .MulticastAddr , e . netProto , false /* multicastLoop */ ); err == nil {
970
+ if r , err := e .stack .FindRoute (0 , tcpip.Address {}, v .MulticastAddr , proto , false /* multicastLoop */ ); err == nil {
964
971
nicID = r .NICID ()
965
972
r .Release ()
966
973
}
967
974
}
968
975
} else {
969
- nicID = e .stack .CheckLocalAddress (nicID , e . netProto , v .InterfaceAddr )
976
+ nicID = e .stack .CheckLocalAddress (nicID , proto , v .InterfaceAddr )
970
977
}
971
978
if nicID == 0 {
972
979
return & tcpip.ErrUnknownDevice {}
@@ -981,27 +988,33 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
981
988
return & tcpip.ErrPortInUse {}
982
989
}
983
990
984
- if err := e .stack .JoinGroup (e . netProto , nicID , v .MulticastAddr ); err != nil {
991
+ if err := e .stack .JoinGroup (proto , nicID , v .MulticastAddr ); err != nil {
985
992
return err
986
993
}
987
994
988
995
e .multicastMemberships [memToInsert ] = struct {}{}
989
996
990
997
case * tcpip.RemoveMembershipOption :
991
- if ! (header .IsV4MulticastAddress (v .MulticastAddr ) && e .netProto == header .IPv4ProtocolNumber ) && ! (header .IsV6MulticastAddress (v .MulticastAddr ) && e .netProto == header .IPv6ProtocolNumber ) {
998
+ var proto tcpip.NetworkProtocolNumber
999
+ switch {
1000
+ case header .IsV4MulticastAddress (v .MulticastAddr ):
1001
+ proto = ipv4 .ProtocolNumber
1002
+ case header .IsV6MulticastAddress (v .MulticastAddr ):
1003
+ proto = ipv6 .ProtocolNumber
1004
+ default :
992
1005
return & tcpip.ErrInvalidOptionValue {}
993
1006
}
994
1007
995
1008
nicID := v .NIC
996
1009
if v .InterfaceAddr .Unspecified () {
997
1010
if nicID == 0 {
998
- if r , err := e .stack .FindRoute (0 , tcpip.Address {}, v .MulticastAddr , e . netProto , false /* multicastLoop */ ); err == nil {
1011
+ if r , err := e .stack .FindRoute (0 , tcpip.Address {}, v .MulticastAddr , proto , false /* multicastLoop */ ); err == nil {
999
1012
nicID = r .NICID ()
1000
1013
r .Release ()
1001
1014
}
1002
1015
}
1003
1016
} else {
1004
- nicID = e .stack .CheckLocalAddress (nicID , e . netProto , v .InterfaceAddr )
1017
+ nicID = e .stack .CheckLocalAddress (nicID , proto , v .InterfaceAddr )
1005
1018
}
1006
1019
if nicID == 0 {
1007
1020
return & tcpip.ErrUnknownDevice {}
@@ -1016,7 +1029,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
1016
1029
return & tcpip.ErrBadLocalAddress {}
1017
1030
}
1018
1031
1019
- if err := e .stack .LeaveGroup (e . netProto , nicID , v .MulticastAddr ); err != nil {
1032
+ if err := e .stack .LeaveGroup (proto , nicID , v .MulticastAddr ); err != nil {
1020
1033
return err
1021
1034
}
1022
1035
0 commit comments