Smoother netlink shutdown

This commit is contained in:
Jason A. Donenfeld 2018-05-14 03:43:56 +02:00
parent c1e097d6d0
commit f738c45a68

View File

@ -31,15 +31,15 @@ const (
) )
type NativeTun struct { type NativeTun struct {
fd *os.File fd *os.File
index int32 // if index index int32 // if index
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan TUNEvent // device related events events chan TUNEvent // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was pased IFF_NO_PI
rwcancel *rwcancel.RWCancel rwcancel *rwcancel.RWCancel
netlinkSock int netlinkSock int
shutdownHackListener chan struct{} statusListenersShutdown chan struct{}
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@ -63,7 +63,7 @@ func (tun *NativeTun) RoutineHackListener() {
} }
select { select {
case <-time.After(time.Second / 10): case <-time.After(time.Second / 10):
case <-tun.shutdownHackListener: case <-tun.statusListenersShutdown:
return return
} }
} }
@ -94,6 +94,12 @@ func (tun *NativeTun) RoutineNetlinkListener() {
return return
} }
select {
case <-tun.statusListenersShutdown:
return
default:
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
@ -328,16 +334,19 @@ func (tun *NativeTun) Events() chan TUNEvent {
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
err1 := tun.fd.Close() close(tun.statusListenersShutdown)
err2 := closeUnblock(tun.netlinkSock) err1 := closeUnblock(tun.netlinkSock)
tun.rwcancel.Cancel() err2 := tun.fd.Close()
err3 := tun.rwcancel.Cancel()
close(tun.events) close(tun.events)
close(tun.shutdownHackListener)
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 if err2 != nil {
return err2
}
return err3
} }
func CreateTUN(name string) (TUNDevice, error) { func CreateTUN(name string) (TUNDevice, error) {
@ -387,11 +396,11 @@ func CreateTUN(name string) (TUNDevice, error) {
func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
device := &NativeTun{ device := &NativeTun{
fd: fd, fd: fd,
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
shutdownHackListener: make(chan struct{}, 0), statusListenersShutdown: make(chan struct{}, 0),
nopi: false, nopi: false,
} }
var err error var err error