Clean more

This commit is contained in:
Jason A. Donenfeld 2018-05-14 12:27:29 +02:00
parent 8b30278ce6
commit 355e9bd619
7 changed files with 52 additions and 65 deletions

View File

@ -217,19 +217,6 @@ func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
} }
} }
func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP {
return net.IPv4(
addr.Addr[0],
addr.Addr[1],
addr.Addr[2],
addr.Addr[3],
)
}
func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP {
return addr.Addr[:]
}
func (end *NativeEndpoint) SrcIP() net.IP { func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return net.IPv4(
@ -624,6 +611,10 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
peer.mutex.RUnlock() peer.mutex.RUnlock()
continue continue
} }
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.mutex.RUnlock()
break
}
nlmsg := struct { nlmsg := struct {
hdr unix.NlMsghdr hdr unix.NlMsghdr
msg unix.RtMsg msg unix.RtMsg

View File

@ -48,19 +48,19 @@ func (st *CookieChecker) Init(pk NoisePublicKey) {
// mac1 state // mac1 state
func() { func() {
hsh, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelMAC1)) hash.Write([]byte(WGLabelMAC1))
hsh.Write(pk[:]) hash.Write(pk[:])
hsh.Sum(st.mac1.key[:0]) hash.Sum(st.mac1.key[:0])
}() }()
// mac2 state // mac2 state
func() { func() {
hsh, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelCookie)) hash.Write([]byte(WGLabelCookie))
hsh.Write(pk[:]) hash.Write(pk[:])
hsh.Sum(st.mac2.encryptionKey[:0]) hash.Sum(st.mac2.encryptionKey[:0])
}() }()
st.mac2.secretSet = time.Time{} st.mac2.secretSet = time.Time{}
@ -181,17 +181,17 @@ func (st *CookieGenerator) Init(pk NoisePublicKey) {
defer st.mutex.Unlock() defer st.mutex.Unlock()
func() { func() {
hsh, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelMAC1)) hash.Write([]byte(WGLabelMAC1))
hsh.Write(pk[:]) hash.Write(pk[:])
hsh.Sum(st.mac1.key[:0]) hash.Sum(st.mac1.key[:0])
}() }()
func() { func() {
hsh, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hsh.Write([]byte(WGLabelCookie)) hash.Write([]byte(WGLabelCookie))
hsh.Write(pk[:]) hash.Write(pk[:])
hsh.Sum(st.mac2.encryptionKey[:0]) hash.Sum(st.mac2.encryptionKey[:0])
}() }()
st.mac2.cookieSet = time.Time{} st.mac2.cookieSet = time.Time{}

View File

@ -225,15 +225,15 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
hs := &peer.handshake handshake := &peer.handshake
if rmKey { if rmKey {
hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else { } else {
hs.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(hs.remoteStatic) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
} }
if isZero(hs.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key) unsafeRemovePeer(device, peer, key)
} }
} }
@ -267,18 +267,12 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.peers.keyMap = make(map[NoisePublicKey]*Peer)
// initialize rate limiter
device.rate.limiter.Init() device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{}) device.rate.underLoadUntil.Store(time.Time{})
// initialize staticIdentity & crypt-key routine
device.indexTable.Init() device.indexTable.Init()
device.allowedips.Reset() device.allowedips.Reset()
// setup buffer pool
device.pool.messageBuffers = sync.Pool{ device.pool.messageBuffers = sync.Pool{
New: func() interface{} { New: func() interface{} {
return new([MaxMessageSize]byte) return new([MaxMessageSize]byte)

View File

@ -186,7 +186,7 @@ func main() {
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
files := [3]*os.File{} files := [3]*os.File{}
if os.Getenv("LOG_LEVEL") != "" { if os.Getenv("LOG_LEVEL") != "" && logLevel != LogLevelSilent {
files[1] = os.Stdout files[1] = os.Stdout
files[2] = os.Stderr files[2] = os.Stderr
} }

View File

@ -121,11 +121,11 @@ func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
} }
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
hsh, _ := blake2s.New256(nil) hash, _ := blake2s.New256(nil)
hsh.Write(h[:]) hash.Write(h[:])
hsh.Write(data) hash.Write(data)
hsh.Sum(dst[:0]) hash.Sum(dst[:0])
hsh.Reset() hash.Reset()
} }
func (h *Handshake) Clear() { func (h *Handshake) Clear() {

View File

@ -125,12 +125,6 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
return nil, err return nil, err
} }
// set default MTU
err = tun.setMTU(DefaultMTU)
if err != nil {
return nil, err
}
tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd())) tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil { if err != nil {
return nil, err return nil, err
@ -174,6 +168,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
} }
}(tun) }(tun)
// set default MTU
err = tun.setMTU(DefaultMTU)
if err != nil {
tun.Close()
return nil, err
}
return tun, nil return tun, nil
} }

View File

@ -395,7 +395,7 @@ func CreateTUN(name string) (TUNDevice, error) {
} }
func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
device := &NativeTun{ tun := &NativeTun{
fd: fd, fd: fd,
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
@ -404,37 +404,38 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
} }
var err error var err error
device.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd())) tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = device.Name() _, err = tun.Name()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// start event listener // start event listener
device.index, err = getIFIndex(device.name) tun.index, err = getIFIndex(tun.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tun.netlinkSock, err = createNetlinkSocket()
if err != nil {
return nil, err
}
go tun.RoutineNetlinkListener()
go tun.RoutineHackListener() // cross namespace
// set default MTU // set default MTU
err = device.setMTU(DefaultMTU) err = tun.setMTU(DefaultMTU)
if err != nil { if err != nil {
tun.Close()
return nil, err return nil, err
} }
device.netlinkSock, err = createNetlinkSocket() return tun, nil
if err != nil {
return nil, err
}
go device.RoutineNetlinkListener()
go device.RoutineHackListener() // cross namespace
return device, nil
} }