Allows passing UAPI fd to service
This commit is contained in:
		
							parent
							
								
									88801529fd
								
							
						
					
					
						commit
						e1227d3af4
					
				
							
								
								
									
										59
									
								
								src/main.go
									
									
									
									
									
								
							
							
						
						
									
										59
									
								
								src/main.go
									
									
									
									
									
								
							@ -9,7 +9,8 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	EnvWGTunFD = "WG_TUN_FD"
 | 
						ENV_WG_TUN_FD  = "WG_TUN_FD"
 | 
				
			||||||
 | 
						ENV_WG_UAPI_FD = "WG_UAPI_FD"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func printUsage() {
 | 
					func printUsage() {
 | 
				
			||||||
@ -65,46 +66,69 @@ func main() {
 | 
				
			|||||||
		logLevel,
 | 
							logLevel,
 | 
				
			||||||
		fmt.Sprintf("(%s) ", interfaceName),
 | 
							fmt.Sprintf("(%s) ", interfaceName),
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger.Debug.Println("Debug log enabled")
 | 
						logger.Debug.Println("Debug log enabled")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// open TUN device
 | 
						// open TUN device (or use supplied fd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tun, err := func() (TUNDevice, error) {
 | 
						tun, err := func() (TUNDevice, error) {
 | 
				
			||||||
		tunFdStr := os.Getenv(EnvWGTunFD)
 | 
							tunFdStr := os.Getenv(ENV_WG_TUN_FD)
 | 
				
			||||||
		if tunFdStr == "" {
 | 
							if tunFdStr == "" {
 | 
				
			||||||
			return CreateTUN(interfaceName)
 | 
								return CreateTUN(interfaceName)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// construct tun device from supplied FD
 | 
							// construct tun device from supplied fd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		fd, err := strconv.ParseUint(tunFdStr, 10, 32)
 | 
							fd, err := strconv.ParseUint(tunFdStr, 10, 32)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		file := os.NewFile(uintptr(fd), "/dev/net/tun")
 | 
							file := os.NewFile(uintptr(fd), "")
 | 
				
			||||||
		return CreateTUNFromFile(interfaceName, file)
 | 
							return CreateTUNFromFile(interfaceName, file)
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Error.Println("Failed to create TUN device:", err)
 | 
							logger.Error.Println("Failed to create TUN device:", err)
 | 
				
			||||||
 | 
							os.Exit(ExitSetupFailed)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// open UAPI file (or use supplied fd)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fileUAPI, err := func() (*os.File, error) {
 | 
				
			||||||
 | 
							uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
 | 
				
			||||||
 | 
							if uapiFdStr == "" {
 | 
				
			||||||
 | 
								return UAPIOpen(interfaceName)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// use supplied fd
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return os.NewFile(uintptr(fd), ""), nil
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							logger.Error.Println("UAPI listen error:", err)
 | 
				
			||||||
 | 
							os.Exit(ExitSetupFailed)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	// daemonize the process
 | 
						// daemonize the process
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !foreground {
 | 
						if !foreground {
 | 
				
			||||||
		env := os.Environ()
 | 
							env := os.Environ()
 | 
				
			||||||
		_, ok := os.LookupEnv(EnvWGTunFD)
 | 
							env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
 | 
				
			||||||
		if !ok {
 | 
							env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
 | 
				
			||||||
			kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
 | 
					 | 
				
			||||||
			env = append(env, kvp)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		attr := &os.ProcAttr{
 | 
							attr := &os.ProcAttr{
 | 
				
			||||||
			Files: []*os.File{
 | 
								Files: []*os.File{
 | 
				
			||||||
				nil, // stdin
 | 
									nil, // stdin
 | 
				
			||||||
				nil, // stdout
 | 
									nil, // stdout
 | 
				
			||||||
				nil, // stderr
 | 
									nil, // stderr
 | 
				
			||||||
				tun.File(),
 | 
									tun.File(),
 | 
				
			||||||
 | 
									fileUAPI,
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
			Dir: ".",
 | 
								Dir: ".",
 | 
				
			||||||
			Env: env,
 | 
								Env: env,
 | 
				
			||||||
@ -112,6 +136,7 @@ func main() {
 | 
				
			|||||||
		err = Daemonize(attr)
 | 
							err = Daemonize(attr)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			logger.Error.Println("Failed to daemonize:", err)
 | 
								logger.Error.Println("Failed to daemonize:", err)
 | 
				
			||||||
 | 
								os.Exit(ExitSetupFailed)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -123,20 +148,17 @@ func main() {
 | 
				
			|||||||
	// create wireguard device
 | 
						// create wireguard device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	device := NewDevice(tun, logger)
 | 
						device := NewDevice(tun, logger)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger.Info.Println("Device started")
 | 
						logger.Info.Println("Device started")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// start configuration lister
 | 
						// start uapi listener
 | 
				
			||||||
 | 
					 | 
				
			||||||
	uapi, err := NewUAPIListener(interfaceName)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error.Println("UAPI listen error:", err)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	errs := make(chan error)
 | 
						errs := make(chan error)
 | 
				
			||||||
	term := make(chan os.Signal)
 | 
						term := make(chan os.Signal)
 | 
				
			||||||
	wait := device.WaitChannel()
 | 
						wait := device.WaitChannel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						uapi, err := UAPIListen(interfaceName, fileUAPI)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go func() {
 | 
						go func() {
 | 
				
			||||||
		for {
 | 
							for {
 | 
				
			||||||
			conn, err := uapi.Accept()
 | 
								conn, err := uapi.Accept()
 | 
				
			||||||
@ -161,9 +183,10 @@ func main() {
 | 
				
			|||||||
	case <-errs:
 | 
						case <-errs:
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// clean up UAPI bind
 | 
						// clean up
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	uapi.Close()
 | 
						uapi.Close()
 | 
				
			||||||
 | 
						device.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	logger.Info.Println("Shutting down")
 | 
						logger.Info.Println("Shutting down")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	val := binary.LittleEndian.Uint32(ifr[16:20])
 | 
						val := binary.LittleEndian.Uint32(ifr[16:20])
 | 
				
			||||||
	if val >= (1 << 31) {
 | 
						if val >= (1 << 31) {
 | 
				
			||||||
		return int(val-(1<<31)) - (1 << 31), nil
 | 
							return int(toInt32(val)), nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return int(val), nil
 | 
						return int(val), nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -10,12 +10,12 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	ipcErrorIO         = -int64(unix.EIO)
 | 
						ipcErrorIO        = -int64(unix.EIO)
 | 
				
			||||||
	ipcErrorProtocol   = -int64(unix.EPROTO)
 | 
						ipcErrorProtocol  = -int64(unix.EPROTO)
 | 
				
			||||||
	ipcErrorInvalid    = -int64(unix.EINVAL)
 | 
						ipcErrorInvalid   = -int64(unix.EINVAL)
 | 
				
			||||||
	ipcErrorPortInUse  = -int64(unix.EADDRINUSE)
 | 
						ipcErrorPortInUse = -int64(unix.EADDRINUSE)
 | 
				
			||||||
	socketDirectory    = "/var/run/wireguard"
 | 
						socketDirectory   = "/var/run/wireguard"
 | 
				
			||||||
	socketName         = "%s.sock"
 | 
						socketName        = "%s.sock"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UAPIListener struct {
 | 
					type UAPIListener struct {
 | 
				
			||||||
@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func connectUnixSocket(path string) (net.Listener, error) {
 | 
					func UAPIListen(name string, file *os.File) (net.Listener, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// attempt inital connection
 | 
						// wrap file in listener
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	listener, err := net.Listen("unix", path)
 | 
						listener, err := net.FileListener(file)
 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		return listener, nil
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// check if active
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	_, err = net.Dial("unix", path)
 | 
					 | 
				
			||||||
	if err == nil {
 | 
					 | 
				
			||||||
		return nil, errors.New("Unix socket in use")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// attempt cleanup
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err = os.Remove(path)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	return net.Listen("unix", path)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func NewUAPIListener(name string) (net.Listener, error) {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// check if path exist
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	err := os.MkdirAll(socketDirectory, 077)
 | 
					 | 
				
			||||||
	if err != nil && !os.IsExist(err) {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// open UNIX socket
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	socketPath := path.Join(
 | 
					 | 
				
			||||||
		socketDirectory,
 | 
					 | 
				
			||||||
		fmt.Sprintf(socketName, name),
 | 
					 | 
				
			||||||
	)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	listener, err := connectUnixSocket(socketPath)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// watch for deletion of socket
 | 
						// watch for deletion of socket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						socketPath := path.Join(
 | 
				
			||||||
 | 
							socketDirectory,
 | 
				
			||||||
 | 
							fmt.Sprintf(socketName, name),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	uapi.inotifyFd, err = unix.InotifyInit()
 | 
						uapi.inotifyFd, err = unix.InotifyInit()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
				
			|||||||
	go func(l *UAPIListener) {
 | 
						go func(l *UAPIListener) {
 | 
				
			||||||
		var buff [4096]byte
 | 
							var buff [4096]byte
 | 
				
			||||||
		for {
 | 
							for {
 | 
				
			||||||
			unix.Read(uapi.inotifyFd, buff[:])
 | 
								// start with lstat to avoid race condition
 | 
				
			||||||
			if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
 | 
								if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
 | 
				
			||||||
				l.connErr <- err
 | 
									l.connErr <- err
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								unix.Read(uapi.inotifyFd, buff[:])
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}(uapi)
 | 
						}(uapi)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return uapi, nil
 | 
						return uapi, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func UAPIOpen(name string) (*os.File, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// check if path exist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						err := os.MkdirAll(socketDirectory, 0600)
 | 
				
			||||||
 | 
						if err != nil && !os.IsExist(err) {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// open UNIX socket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						socketPath := path.Join(
 | 
				
			||||||
 | 
							socketDirectory,
 | 
				
			||||||
 | 
							fmt.Sprintf(socketName, name),
 | 
				
			||||||
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						addr, err := net.ResolveUnixAddr("unix", socketPath)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						listener, err := func() (*net.UnixListener, error) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// initial connection attempt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							listener, err := net.ListenUnix("unix", addr)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								return listener, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// check if socket already active
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							_, err = net.Dial("unix", socketPath)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								return nil, errors.New("unix socket in use")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// cleanup & attempt again
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							err = os.Remove(socketPath)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return net.ListenUnix("unix", addr)
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return listener.File()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user