Added code from windows branch
This commit is contained in:
		
							parent
							
								
									eafa3df606
								
							
						
					
					
						commit
						6f5ef153c3
					
				
							
								
								
									
										6
									
								
								src/build.cmd
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										6
									
								
								src/build.cmd
									
									
									
									
									
										Executable file
									
								
							@ -0,0 +1,6 @@
 | 
			
		||||
@echo off
 | 
			
		||||
 | 
			
		||||
REM builds wireguard for windows
 | 
			
		||||
 | 
			
		||||
go get
 | 
			
		||||
go build -o wireguard-go.exe
 | 
			
		||||
@ -6,6 +6,6 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func setFwmark(conn *net.UDPConn, value int) error {
 | 
			
		||||
func setMark(conn *net.UDPConn, value int) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										34
									
								
								src/daemon_windows.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								src/daemon_windows.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Daemonizes the process on windows
 | 
			
		||||
 *
 | 
			
		||||
 * This is done by spawning and releasing a copy with the --foreground flag
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
func Daemonize() error {
 | 
			
		||||
	argv := []string{os.Args[0], "--foreground"}
 | 
			
		||||
	argv = append(argv, os.Args[1:]...)
 | 
			
		||||
	attr := &os.ProcAttr{
 | 
			
		||||
		Dir: ".",
 | 
			
		||||
		Env: os.Environ(),
 | 
			
		||||
		Files: []*os.File{
 | 
			
		||||
			os.Stdin,
 | 
			
		||||
			nil,
 | 
			
		||||
			nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	process, err := os.StartProcess(
 | 
			
		||||
		argv[0],
 | 
			
		||||
		argv,
 | 
			
		||||
		attr,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	process.Release()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										672
									
								
								src/timers.go
									
									
									
									
									
								
							
							
						
						
									
										672
									
								
								src/timers.go
									
									
									
									
									
								
							@ -1,336 +1,336 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Called when a new authenticated message has been send
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) KeepKeyFreshSending() {
 | 
			
		||||
	kp := peer.keyPairs.Current()
 | 
			
		||||
	if kp == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	nonce := atomic.LoadUint64(&kp.sendNonce)
 | 
			
		||||
	if nonce > RekeyAfterMessages {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
	if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Called when a new authenticated message has been recevied
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) KeepKeyFreshReceiving() {
 | 
			
		||||
	// TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
 | 
			
		||||
	kp := peer.keyPairs.Current()
 | 
			
		||||
	if kp == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !kp.isInitiator {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	nonce := atomic.LoadUint64(&kp.sendNonce)
 | 
			
		||||
	send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
 | 
			
		||||
	if send {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Queues a keep-alive if no packets are queued for peer
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) SendKeepAlive() bool {
 | 
			
		||||
	elem := peer.device.NewOutboundElement()
 | 
			
		||||
	elem.packet = nil
 | 
			
		||||
	if len(peer.queue.nonce) == 0 {
 | 
			
		||||
		select {
 | 
			
		||||
		case peer.queue.nonce <- elem:
 | 
			
		||||
			return true
 | 
			
		||||
		default:
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Sent non-empty (authenticated) transport message
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerDataSent() {
 | 
			
		||||
	timerStop(peer.timer.keepalivePassive)
 | 
			
		||||
	if !peer.timer.pendingNewHandshake {
 | 
			
		||||
		peer.timer.pendingNewHandshake = true
 | 
			
		||||
		peer.timer.newHandshake.Reset(NewHandshakeTime)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Received non-empty (authenticated) transport message
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerDataReceived() {
 | 
			
		||||
	if peer.timer.pendingKeepalivePassive {
 | 
			
		||||
		peer.timer.needAnotherKeepalive = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	peer.timer.pendingKeepalivePassive = false
 | 
			
		||||
	peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Any (authenticated) packet received
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
 | 
			
		||||
	timerStop(peer.timer.newHandshake)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Any authenticated packet send / received.
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
 | 
			
		||||
	interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
 | 
			
		||||
	if interval > 0 {
 | 
			
		||||
		duration := time.Duration(interval) * time.Second
 | 
			
		||||
		peer.timer.keepalivePersistent.Reset(duration)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Called after succesfully completing a handshake.
 | 
			
		||||
 * i.e. after:
 | 
			
		||||
 *
 | 
			
		||||
 * - Valid handshake response
 | 
			
		||||
 * - First transport message under the "next" key
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerHandshakeComplete() {
 | 
			
		||||
	atomic.StoreInt64(
 | 
			
		||||
		&peer.stats.lastHandshakeNano,
 | 
			
		||||
		time.Now().UnixNano(),
 | 
			
		||||
	)
 | 
			
		||||
	signalSend(peer.signal.handshakeCompleted)
 | 
			
		||||
	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * An ephemeral key is generated
 | 
			
		||||
 *
 | 
			
		||||
 * i.e after:
 | 
			
		||||
 *
 | 
			
		||||
 * CreateMessageInitiation
 | 
			
		||||
 * CreateMessageResponse
 | 
			
		||||
 *
 | 
			
		||||
 * Schedules the deletion of all key material
 | 
			
		||||
 * upon failure to complete a handshake
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerEphemeralKeyCreated() {
 | 
			
		||||
	peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
	indices := &device.indices
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, timer handler, started for peer", peer.String())
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
 | 
			
		||||
		case <-peer.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		// keep-alives
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepalivePersistent.C:
 | 
			
		||||
 | 
			
		||||
			interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
 | 
			
		||||
			if interval > 0 {
 | 
			
		||||
				logDebug.Println("Sending keep-alive to", peer.String())
 | 
			
		||||
				peer.SendKeepAlive()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepalivePassive.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Sending keep-alive to", peer.String())
 | 
			
		||||
 | 
			
		||||
			peer.SendKeepAlive()
 | 
			
		||||
 | 
			
		||||
			if peer.timer.needAnotherKeepalive {
 | 
			
		||||
				peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
 | 
			
		||||
				peer.timer.needAnotherKeepalive = false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// unresponsive session
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.newHandshake.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
 | 
			
		||||
 | 
			
		||||
			signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
 | 
			
		||||
		// clear key material
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.zeroAllKeys.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Clearing all key material for", peer.String())
 | 
			
		||||
 | 
			
		||||
			hs := &peer.handshake
 | 
			
		||||
			hs.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			kp := &peer.keyPairs
 | 
			
		||||
			kp.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			// unmap indecies
 | 
			
		||||
 | 
			
		||||
			indices.mutex.Lock()
 | 
			
		||||
			if kp.previous != nil {
 | 
			
		||||
				delete(indices.table, kp.previous.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.current != nil {
 | 
			
		||||
				delete(indices.table, kp.current.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.next != nil {
 | 
			
		||||
				delete(indices.table, kp.next.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			delete(indices.table, hs.localIndex)
 | 
			
		||||
			indices.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out key pairs (TODO: better than wait for GC)
 | 
			
		||||
 | 
			
		||||
			kp.current = nil
 | 
			
		||||
			kp.previous = nil
 | 
			
		||||
			kp.next = nil
 | 
			
		||||
			kp.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out handshake
 | 
			
		||||
 | 
			
		||||
			hs.localIndex = 0
 | 
			
		||||
			hs.localEphemeral = NoisePrivateKey{}
 | 
			
		||||
			hs.remoteEphemeral = NoisePublicKey{}
 | 
			
		||||
			hs.chainKey = [blake2s.Size]byte{}
 | 
			
		||||
			hs.hash = [blake2s.Size]byte{}
 | 
			
		||||
			hs.mutex.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* This is the state machine for handshake initiation
 | 
			
		||||
 *
 | 
			
		||||
 * Associated with this routine is the signal "handshakeBegin"
 | 
			
		||||
 * The routine will read from the "handshakeBegin" channel
 | 
			
		||||
 * at most every RekeyTimeout seconds
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
 | 
			
		||||
	logInfo := device.log.Info
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, handshake initator, started for", peer.String())
 | 
			
		||||
 | 
			
		||||
	var temp [256]byte
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// wait for signal
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-peer.signal.handshakeBegin:
 | 
			
		||||
		case <-peer.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// set deadline
 | 
			
		||||
 | 
			
		||||
	BeginHandshakes:
 | 
			
		||||
 | 
			
		||||
		signalClear(peer.signal.handshakeReset)
 | 
			
		||||
		deadline := time.NewTimer(RekeyAttemptTime)
 | 
			
		||||
 | 
			
		||||
	AttemptHandshakes:
 | 
			
		||||
 | 
			
		||||
		for attempts := uint(1); ; attempts++ {
 | 
			
		||||
 | 
			
		||||
			// check if deadline reached
 | 
			
		||||
 | 
			
		||||
			select {
 | 
			
		||||
			case <-deadline.C:
 | 
			
		||||
				logInfo.Println("Handshake negotiation timed out for:", peer.String())
 | 
			
		||||
				signalSend(peer.signal.flushNonceQueue)
 | 
			
		||||
				timerStop(peer.timer.keepalivePersistent)
 | 
			
		||||
				break
 | 
			
		||||
			case <-peer.signal.stop:
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			signalClear(peer.signal.handshakeCompleted)
 | 
			
		||||
 | 
			
		||||
			// create initiation message
 | 
			
		||||
 | 
			
		||||
			msg, err := peer.device.CreateMessageInitiation(peer)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Println("Failed to create handshake initiation message:", err)
 | 
			
		||||
				break AttemptHandshakes
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
 | 
			
		||||
 | 
			
		||||
			// marshal and send
 | 
			
		||||
 | 
			
		||||
			writer := bytes.NewBuffer(temp[:0])
 | 
			
		||||
			binary.Write(writer, binary.LittleEndian, msg)
 | 
			
		||||
			packet := writer.Bytes()
 | 
			
		||||
			peer.mac.AddMacs(packet)
 | 
			
		||||
 | 
			
		||||
			_, err = peer.SendBuffer(packet)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Println(
 | 
			
		||||
					"Failed to send handshake initiation message to",
 | 
			
		||||
					peer.String(), ":", err,
 | 
			
		||||
				)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			peer.TimerAnyAuthenticatedPacketTraversal()
 | 
			
		||||
 | 
			
		||||
			// set handshake timeout
 | 
			
		||||
 | 
			
		||||
			timeout := time.NewTimer(RekeyTimeout + jitter)
 | 
			
		||||
			logDebug.Println(
 | 
			
		||||
				"Handshake initiation attempt",
 | 
			
		||||
				attempts, "sent to", peer.String(),
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			// wait for handshake or timeout
 | 
			
		||||
 | 
			
		||||
			select {
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.stop:
 | 
			
		||||
				return
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.handshakeCompleted:
 | 
			
		||||
				<-timeout.C
 | 
			
		||||
				break AttemptHandshakes
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.handshakeReset:
 | 
			
		||||
				<-timeout.C
 | 
			
		||||
				goto BeginHandshakes
 | 
			
		||||
 | 
			
		||||
			case <-timeout.C:
 | 
			
		||||
				// TODO: Clear source address for peer
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// clear signal set in the meantime
 | 
			
		||||
 | 
			
		||||
		signalClear(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"golang.org/x/crypto/blake2s"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Called when a new authenticated message has been send
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) KeepKeyFreshSending() {
 | 
			
		||||
	kp := peer.keyPairs.Current()
 | 
			
		||||
	if kp == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	nonce := atomic.LoadUint64(&kp.sendNonce)
 | 
			
		||||
	if nonce > RekeyAfterMessages {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
	if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Called when a new authenticated message has been recevied
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) KeepKeyFreshReceiving() {
 | 
			
		||||
	// TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
 | 
			
		||||
	kp := peer.keyPairs.Current()
 | 
			
		||||
	if kp == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !kp.isInitiator {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	nonce := atomic.LoadUint64(&kp.sendNonce)
 | 
			
		||||
	send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
 | 
			
		||||
	if send {
 | 
			
		||||
		signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Queues a keep-alive if no packets are queued for peer
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) SendKeepAlive() bool {
 | 
			
		||||
	elem := peer.device.NewOutboundElement()
 | 
			
		||||
	elem.packet = nil
 | 
			
		||||
	if len(peer.queue.nonce) == 0 {
 | 
			
		||||
		select {
 | 
			
		||||
		case peer.queue.nonce <- elem:
 | 
			
		||||
			return true
 | 
			
		||||
		default:
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Sent non-empty (authenticated) transport message
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerDataSent() {
 | 
			
		||||
	timerStop(peer.timer.keepalivePassive)
 | 
			
		||||
	if !peer.timer.pendingNewHandshake {
 | 
			
		||||
		peer.timer.pendingNewHandshake = true
 | 
			
		||||
		peer.timer.newHandshake.Reset(NewHandshakeTime)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Received non-empty (authenticated) transport message
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerDataReceived() {
 | 
			
		||||
	if peer.timer.pendingKeepalivePassive {
 | 
			
		||||
		peer.timer.needAnotherKeepalive = true
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	peer.timer.pendingKeepalivePassive = false
 | 
			
		||||
	peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Any (authenticated) packet received
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
 | 
			
		||||
	timerStop(peer.timer.newHandshake)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * Any authenticated packet send / received.
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
 | 
			
		||||
	interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
 | 
			
		||||
	if interval > 0 {
 | 
			
		||||
		duration := time.Duration(interval) * time.Second
 | 
			
		||||
		peer.timer.keepalivePersistent.Reset(duration)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Called after succesfully completing a handshake.
 | 
			
		||||
 * i.e. after:
 | 
			
		||||
 *
 | 
			
		||||
 * - Valid handshake response
 | 
			
		||||
 * - First transport message under the "next" key
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerHandshakeComplete() {
 | 
			
		||||
	atomic.StoreInt64(
 | 
			
		||||
		&peer.stats.lastHandshakeNano,
 | 
			
		||||
		time.Now().UnixNano(),
 | 
			
		||||
	)
 | 
			
		||||
	signalSend(peer.signal.handshakeCompleted)
 | 
			
		||||
	peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* Event:
 | 
			
		||||
 * An ephemeral key is generated
 | 
			
		||||
 *
 | 
			
		||||
 * i.e after:
 | 
			
		||||
 *
 | 
			
		||||
 * CreateMessageInitiation
 | 
			
		||||
 * CreateMessageResponse
 | 
			
		||||
 *
 | 
			
		||||
 * Schedules the deletion of all key material
 | 
			
		||||
 * upon failure to complete a handshake
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) TimerEphemeralKeyCreated() {
 | 
			
		||||
	peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (peer *Peer) RoutineTimerHandler() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
	indices := &device.indices
 | 
			
		||||
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, timer handler, started for peer", peer.String())
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
 | 
			
		||||
		case <-peer.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
 | 
			
		||||
		// keep-alives
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepalivePersistent.C:
 | 
			
		||||
 | 
			
		||||
			interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
 | 
			
		||||
			if interval > 0 {
 | 
			
		||||
				logDebug.Println("Sending keep-alive to", peer.String())
 | 
			
		||||
				peer.SendKeepAlive()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.keepalivePassive.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Sending keep-alive to", peer.String())
 | 
			
		||||
 | 
			
		||||
			peer.SendKeepAlive()
 | 
			
		||||
 | 
			
		||||
			if peer.timer.needAnotherKeepalive {
 | 
			
		||||
				peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
 | 
			
		||||
				peer.timer.needAnotherKeepalive = false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		// unresponsive session
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.newHandshake.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
 | 
			
		||||
 | 
			
		||||
			signalSend(peer.signal.handshakeBegin)
 | 
			
		||||
 | 
			
		||||
		// clear key material
 | 
			
		||||
 | 
			
		||||
		case <-peer.timer.zeroAllKeys.C:
 | 
			
		||||
 | 
			
		||||
			logDebug.Println("Clearing all key material for", peer.String())
 | 
			
		||||
 | 
			
		||||
			hs := &peer.handshake
 | 
			
		||||
			hs.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			kp := &peer.keyPairs
 | 
			
		||||
			kp.mutex.Lock()
 | 
			
		||||
 | 
			
		||||
			// unmap indecies
 | 
			
		||||
 | 
			
		||||
			indices.mutex.Lock()
 | 
			
		||||
			if kp.previous != nil {
 | 
			
		||||
				delete(indices.table, kp.previous.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.current != nil {
 | 
			
		||||
				delete(indices.table, kp.current.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			if kp.next != nil {
 | 
			
		||||
				delete(indices.table, kp.next.localIndex)
 | 
			
		||||
			}
 | 
			
		||||
			delete(indices.table, hs.localIndex)
 | 
			
		||||
			indices.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out key pairs (TODO: better than wait for GC)
 | 
			
		||||
 | 
			
		||||
			kp.current = nil
 | 
			
		||||
			kp.previous = nil
 | 
			
		||||
			kp.next = nil
 | 
			
		||||
			kp.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
			// zero out handshake
 | 
			
		||||
 | 
			
		||||
			hs.localIndex = 0
 | 
			
		||||
			hs.localEphemeral = NoisePrivateKey{}
 | 
			
		||||
			hs.remoteEphemeral = NoisePublicKey{}
 | 
			
		||||
			hs.chainKey = [blake2s.Size]byte{}
 | 
			
		||||
			hs.hash = [blake2s.Size]byte{}
 | 
			
		||||
			hs.mutex.Unlock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* This is the state machine for handshake initiation
 | 
			
		||||
 *
 | 
			
		||||
 * Associated with this routine is the signal "handshakeBegin"
 | 
			
		||||
 * The routine will read from the "handshakeBegin" channel
 | 
			
		||||
 * at most every RekeyTimeout seconds
 | 
			
		||||
 */
 | 
			
		||||
func (peer *Peer) RoutineHandshakeInitiator() {
 | 
			
		||||
	device := peer.device
 | 
			
		||||
 | 
			
		||||
	logInfo := device.log.Info
 | 
			
		||||
	logError := device.log.Error
 | 
			
		||||
	logDebug := device.log.Debug
 | 
			
		||||
	logDebug.Println("Routine, handshake initator, started for", peer.String())
 | 
			
		||||
 | 
			
		||||
	var temp [256]byte
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
 | 
			
		||||
		// wait for signal
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-peer.signal.handshakeBegin:
 | 
			
		||||
		case <-peer.signal.stop:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// set deadline
 | 
			
		||||
 | 
			
		||||
	BeginHandshakes:
 | 
			
		||||
 | 
			
		||||
		signalClear(peer.signal.handshakeReset)
 | 
			
		||||
		deadline := time.NewTimer(RekeyAttemptTime)
 | 
			
		||||
 | 
			
		||||
	AttemptHandshakes:
 | 
			
		||||
 | 
			
		||||
		for attempts := uint(1); ; attempts++ {
 | 
			
		||||
 | 
			
		||||
			// check if deadline reached
 | 
			
		||||
 | 
			
		||||
			select {
 | 
			
		||||
			case <-deadline.C:
 | 
			
		||||
				logInfo.Println("Handshake negotiation timed out for:", peer.String())
 | 
			
		||||
				signalSend(peer.signal.flushNonceQueue)
 | 
			
		||||
				timerStop(peer.timer.keepalivePersistent)
 | 
			
		||||
				break
 | 
			
		||||
			case <-peer.signal.stop:
 | 
			
		||||
				return
 | 
			
		||||
			default:
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			signalClear(peer.signal.handshakeCompleted)
 | 
			
		||||
 | 
			
		||||
			// create initiation message
 | 
			
		||||
 | 
			
		||||
			msg, err := peer.device.CreateMessageInitiation(peer)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Println("Failed to create handshake initiation message:", err)
 | 
			
		||||
				break AttemptHandshakes
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
 | 
			
		||||
 | 
			
		||||
			// marshal and send
 | 
			
		||||
 | 
			
		||||
			writer := bytes.NewBuffer(temp[:0])
 | 
			
		||||
			binary.Write(writer, binary.LittleEndian, msg)
 | 
			
		||||
			packet := writer.Bytes()
 | 
			
		||||
			peer.mac.AddMacs(packet)
 | 
			
		||||
 | 
			
		||||
			_, err = peer.SendBuffer(packet)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logError.Println(
 | 
			
		||||
					"Failed to send handshake initiation message to",
 | 
			
		||||
					peer.String(), ":", err,
 | 
			
		||||
				)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			peer.TimerAnyAuthenticatedPacketTraversal()
 | 
			
		||||
 | 
			
		||||
			// set handshake timeout
 | 
			
		||||
 | 
			
		||||
			timeout := time.NewTimer(RekeyTimeout + jitter)
 | 
			
		||||
			logDebug.Println(
 | 
			
		||||
				"Handshake initiation attempt",
 | 
			
		||||
				attempts, "sent to", peer.String(),
 | 
			
		||||
			)
 | 
			
		||||
 | 
			
		||||
			// wait for handshake or timeout
 | 
			
		||||
 | 
			
		||||
			select {
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.stop:
 | 
			
		||||
				return
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.handshakeCompleted:
 | 
			
		||||
				<-timeout.C
 | 
			
		||||
				break AttemptHandshakes
 | 
			
		||||
 | 
			
		||||
			case <-peer.signal.handshakeReset:
 | 
			
		||||
				<-timeout.C
 | 
			
		||||
				goto BeginHandshakes
 | 
			
		||||
 | 
			
		||||
			case <-timeout.C:
 | 
			
		||||
				// TODO: Clear source address for peer
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// clear signal set in the meantime
 | 
			
		||||
 | 
			
		||||
		signalClear(peer.signal.handshakeBegin)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										475
									
								
								src/tun_windows.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										475
									
								
								src/tun_windows.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,475 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"golang.org/x/sys/windows"
 | 
			
		||||
	"golang.org/x/sys/windows/registry"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version)
 | 
			
		||||
 *
 | 
			
		||||
 * https://github.com/OpenVPN/tap-windows
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
type NativeTUN struct {
 | 
			
		||||
	fd     windows.Handle
 | 
			
		||||
	rl     sync.Mutex
 | 
			
		||||
	wl     sync.Mutex
 | 
			
		||||
	ro     *windows.Overlapped
 | 
			
		||||
	wo     *windows.Overlapped
 | 
			
		||||
	events chan TUNEvent
 | 
			
		||||
	name   string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	METHOD_BUFFERED = 0
 | 
			
		||||
	ComponentID     = "tap0901" // tap0801
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ctl_code(device_type, function, method, access uint32) uint32 {
 | 
			
		||||
	return (device_type << 16) | (access << 14) | (function << 2) | method
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TAP_CONTROL_CODE(request, method uint32) uint32 {
 | 
			
		||||
	return ctl_code(file_device_unknown, request, method, 0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errIfceNameNotFound = errors.New("Failed to find the name of interface")
 | 
			
		||||
 | 
			
		||||
	TAP_IOCTL_GET_MAC               = TAP_CONTROL_CODE(1, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_GET_VERSION           = TAP_CONTROL_CODE(2, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_GET_MTU               = TAP_CONTROL_CODE(3, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_GET_INFO              = TAP_CONTROL_CODE(4, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_SET_MEDIA_STATUS      = TAP_CONTROL_CODE(6, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_CONFIG_DHCP_MASQ      = TAP_CONTROL_CODE(7, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_GET_LOG_LINE          = TAP_CONTROL_CODE(8, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_CONFIG_DHCP_SET_OPT   = TAP_CONTROL_CODE(9, METHOD_BUFFERED)
 | 
			
		||||
	TAP_IOCTL_CONFIG_TUN            = TAP_CONTROL_CODE(10, METHOD_BUFFERED)
 | 
			
		||||
 | 
			
		||||
	file_device_unknown = uint32(0x00000022)
 | 
			
		||||
	nCreateEvent,
 | 
			
		||||
	nResetEvent,
 | 
			
		||||
	nGetOverlappedResult uintptr
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	k32, err := windows.LoadLibrary("kernel32.dll")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic("LoadLibrary " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	defer windows.FreeLibrary(k32)
 | 
			
		||||
	nCreateEvent = getProcAddr(k32, "CreateEventW")
 | 
			
		||||
	nResetEvent = getProcAddr(k32, "ResetEvent")
 | 
			
		||||
	nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* implementation of the read/write/closer interface */
 | 
			
		||||
 | 
			
		||||
func getProcAddr(lib windows.Handle, name string) uintptr {
 | 
			
		||||
	addr, err := windows.GetProcAddress(lib, name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(name + " " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	return addr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func resetEvent(h windows.Handle) error {
 | 
			
		||||
	r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
 | 
			
		||||
	if r == 0 {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
 | 
			
		||||
	var n int
 | 
			
		||||
	r, _, err := syscall.Syscall6(
 | 
			
		||||
		nGetOverlappedResult,
 | 
			
		||||
		4,
 | 
			
		||||
		uintptr(h),
 | 
			
		||||
		uintptr(unsafe.Pointer(overlapped)),
 | 
			
		||||
		uintptr(unsafe.Pointer(&n)), 1, 0, 0)
 | 
			
		||||
 | 
			
		||||
	if r == 0 {
 | 
			
		||||
		return n, err
 | 
			
		||||
	}
 | 
			
		||||
	return n, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOverlapped() (*windows.Overlapped, error) {
 | 
			
		||||
	var overlapped windows.Overlapped
 | 
			
		||||
	r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
 | 
			
		||||
	if r == 0 {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	overlapped.HEvent = windows.Handle(r)
 | 
			
		||||
	return &overlapped, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Events() chan TUNEvent {
 | 
			
		||||
	return f.events
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Close() error {
 | 
			
		||||
	return windows.Close(f.fd)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Write(b []byte) (int, error) {
 | 
			
		||||
	f.wl.Lock()
 | 
			
		||||
	defer f.wl.Unlock()
 | 
			
		||||
 | 
			
		||||
	if err := resetEvent(f.wo.HEvent); err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	var n uint32
 | 
			
		||||
	err := windows.WriteFile(f.fd, b, &n, f.wo)
 | 
			
		||||
	if err != nil && err != windows.ERROR_IO_PENDING {
 | 
			
		||||
		return int(n), err
 | 
			
		||||
	}
 | 
			
		||||
	return getOverlappedResult(f.fd, f.wo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *NativeTUN) Read(b []byte) (int, error) {
 | 
			
		||||
	f.rl.Lock()
 | 
			
		||||
	defer f.rl.Unlock()
 | 
			
		||||
 | 
			
		||||
	if err := resetEvent(f.ro.HEvent); err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	var done uint32
 | 
			
		||||
	err := windows.ReadFile(f.fd, b, &done, f.ro)
 | 
			
		||||
	if err != nil && err != windows.ERROR_IO_PENDING {
 | 
			
		||||
		return int(done), err
 | 
			
		||||
	}
 | 
			
		||||
	return getOverlappedResult(f.fd, f.ro)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getdeviceid(
 | 
			
		||||
	targetComponentId string,
 | 
			
		||||
	targetDeviceName string,
 | 
			
		||||
) (deviceid string, err error) {
 | 
			
		||||
 | 
			
		||||
	getName := func(instanceId string) (string, error) {
 | 
			
		||||
		path := fmt.Sprintf(
 | 
			
		||||
			`SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`,
 | 
			
		||||
			instanceId,
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		key, err := registry.OpenKey(
 | 
			
		||||
			registry.LOCAL_MACHINE,
 | 
			
		||||
			path,
 | 
			
		||||
			registry.READ,
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", err
 | 
			
		||||
		}
 | 
			
		||||
		defer key.Close()
 | 
			
		||||
 | 
			
		||||
		val, _, err := key.GetStringValue("Name")
 | 
			
		||||
		key.Close()
 | 
			
		||||
		return val, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	getInstanceId := func(keyName string) (string, string, error) {
 | 
			
		||||
		path := fmt.Sprintf(
 | 
			
		||||
			`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`,
 | 
			
		||||
			keyName,
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		key, err := registry.OpenKey(
 | 
			
		||||
			registry.LOCAL_MACHINE,
 | 
			
		||||
			path,
 | 
			
		||||
			registry.READ,
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
		defer key.Close()
 | 
			
		||||
 | 
			
		||||
		componentId, _, err := key.GetStringValue("ComponentId")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", "", err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		instanceId, _, err := key.GetStringValue("NetCfgInstanceId")
 | 
			
		||||
 | 
			
		||||
		return componentId, instanceId, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// find list of all network devices
 | 
			
		||||
 | 
			
		||||
	k, err := registry.OpenKey(
 | 
			
		||||
		registry.LOCAL_MACHINE,
 | 
			
		||||
		`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`,
 | 
			
		||||
		registry.READ,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer k.Close()
 | 
			
		||||
 | 
			
		||||
	keys, err := k.ReadSubKeyNames(-1)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// look for matching component id and name
 | 
			
		||||
 | 
			
		||||
	var componentFound bool
 | 
			
		||||
 | 
			
		||||
	for _, v := range keys {
 | 
			
		||||
 | 
			
		||||
		componentId, instanceId, err := getInstanceId(v)
 | 
			
		||||
		if err != nil || componentId != targetComponentId {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		componentFound = true
 | 
			
		||||
 | 
			
		||||
		deviceName, err := getName(instanceId)
 | 
			
		||||
		if err != nil || deviceName != targetDeviceName {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return instanceId, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// provide a descriptive error message
 | 
			
		||||
 | 
			
		||||
	if componentFound {
 | 
			
		||||
		return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return "", fmt.Errorf(
 | 
			
		||||
		"Unable to find device in registry with ComponentId = %s, is tap-windows installed?",
 | 
			
		||||
		targetComponentId,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// setStatus is used to bring up or bring down the interface
 | 
			
		||||
func setStatus(fd windows.Handle, status bool) error {
 | 
			
		||||
	var code [4]byte
 | 
			
		||||
	if status {
 | 
			
		||||
		binary.LittleEndian.PutUint32(code[:], 1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var bytesReturned uint32
 | 
			
		||||
	rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
 | 
			
		||||
	return windows.DeviceIoControl(
 | 
			
		||||
		fd,
 | 
			
		||||
		TAP_IOCTL_SET_MEDIA_STATUS,
 | 
			
		||||
		&code[0],
 | 
			
		||||
		uint32(4),
 | 
			
		||||
		&rdbbuf[0],
 | 
			
		||||
		uint32(len(rdbbuf)),
 | 
			
		||||
		&bytesReturned,
 | 
			
		||||
		nil,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* When operating in TUN mode we must assign an ip address & subnet to the device.
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
func setTUN(fd windows.Handle, network string) error {
 | 
			
		||||
	var bytesReturned uint32
 | 
			
		||||
	rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
 | 
			
		||||
	localIP, remoteNet, err := net.ParseCIDR(network)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("Failed to parse network CIDR in config, %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if localIP.To4() == nil {
 | 
			
		||||
		return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var param [12]byte
 | 
			
		||||
 | 
			
		||||
	copy(param[0:4], localIP.To4())
 | 
			
		||||
	copy(param[4:8], remoteNet.IP.To4())
 | 
			
		||||
	copy(param[8:12], remoteNet.Mask)
 | 
			
		||||
 | 
			
		||||
	return windows.DeviceIoControl(
 | 
			
		||||
		fd,
 | 
			
		||||
		TAP_IOCTL_CONFIG_TUN,
 | 
			
		||||
		¶m[0],
 | 
			
		||||
		uint32(12),
 | 
			
		||||
		&rdbbuf[0],
 | 
			
		||||
		uint32(len(rdbbuf)),
 | 
			
		||||
		&bytesReturned,
 | 
			
		||||
		nil,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTUN) MTU() (int, error) {
 | 
			
		||||
	var mtu [4]byte
 | 
			
		||||
	var bytesReturned uint32
 | 
			
		||||
	err := windows.DeviceIoControl(
 | 
			
		||||
		tun.fd,
 | 
			
		||||
		TAP_IOCTL_GET_MTU,
 | 
			
		||||
		&mtu[0],
 | 
			
		||||
		uint32(len(mtu)),
 | 
			
		||||
		&mtu[0],
 | 
			
		||||
		uint32(len(mtu)),
 | 
			
		||||
		&bytesReturned,
 | 
			
		||||
		nil,
 | 
			
		||||
	)
 | 
			
		||||
	val := binary.LittleEndian.Uint32(mtu[:])
 | 
			
		||||
	return int(val), err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (tun *NativeTUN) Name() string {
 | 
			
		||||
	return tun.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CreateTUN(name string) (TUNDevice, error) {
 | 
			
		||||
 | 
			
		||||
	// find the device in registry.
 | 
			
		||||
 | 
			
		||||
	deviceid, err := getdeviceid(ComponentID, name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	path := "\\\\.\\Global\\" + deviceid + ".tap"
 | 
			
		||||
	pathp, err := windows.UTF16PtrFromString(path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create TUN device
 | 
			
		||||
 | 
			
		||||
	handle, err := windows.CreateFile(
 | 
			
		||||
		pathp,
 | 
			
		||||
		windows.GENERIC_READ|windows.GENERIC_WRITE,
 | 
			
		||||
		0,
 | 
			
		||||
		nil,
 | 
			
		||||
		windows.OPEN_EXISTING,
 | 
			
		||||
		windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
 | 
			
		||||
		0,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ro, err := newOverlapped()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wo, err := newOverlapped()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tun := &NativeTUN{
 | 
			
		||||
		fd:     handle,
 | 
			
		||||
		name:   name,
 | 
			
		||||
		ro:     ro,
 | 
			
		||||
		wo:     wo,
 | 
			
		||||
		events: make(chan TUNEvent, 5),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// find addresses of interface
 | 
			
		||||
	// TODO: fix this hack, the question is how
 | 
			
		||||
 | 
			
		||||
	inter, err := net.InterfaceByName(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	addrs, err := inter.Addrs()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var ip net.IP
 | 
			
		||||
	for _, addr := range addrs {
 | 
			
		||||
		ip = func() net.IP {
 | 
			
		||||
			switch v := addr.(type) {
 | 
			
		||||
			case *net.IPNet:
 | 
			
		||||
				return v.IP.To4()
 | 
			
		||||
			case *net.IPAddr:
 | 
			
		||||
				return v.IP.To4()
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}()
 | 
			
		||||
		if ip != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ip == nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, errors.New("No IPv4 address found for interface")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// bring up device.
 | 
			
		||||
 | 
			
		||||
	if err := setStatus(handle, true); err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// set tun mode
 | 
			
		||||
 | 
			
		||||
	mask := ip.String() + "/0"
 | 
			
		||||
	if err := setTUN(handle, mask); err != nil {
 | 
			
		||||
		windows.Close(handle)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// start listener
 | 
			
		||||
 | 
			
		||||
	go func(native *NativeTUN, ifname string) {
 | 
			
		||||
		// TODO: Fix this very niave implementation
 | 
			
		||||
		var (
 | 
			
		||||
			statusUp  bool
 | 
			
		||||
			statusMTU int
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		for ; ; time.Sleep(time.Second) {
 | 
			
		||||
			intr, err := net.InterfaceByName(name)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				// TODO: handle
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Up / Down event
 | 
			
		||||
			up := (intr.Flags & net.FlagUp) != 0
 | 
			
		||||
			if up != statusUp && up {
 | 
			
		||||
				native.events <- TUNEventUp
 | 
			
		||||
			}
 | 
			
		||||
			if up != statusUp && !up {
 | 
			
		||||
				native.events <- TUNEventDown
 | 
			
		||||
			}
 | 
			
		||||
			statusUp = up
 | 
			
		||||
 | 
			
		||||
			// MTU changes
 | 
			
		||||
			if intr.MTU != statusMTU {
 | 
			
		||||
				native.events <- TUNEventMTUUpdate
 | 
			
		||||
			}
 | 
			
		||||
			statusMTU = intr.MTU
 | 
			
		||||
		}
 | 
			
		||||
	}(tun, name)
 | 
			
		||||
 | 
			
		||||
	return tun, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								src/uapi_windows.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								src/uapi_windows.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,44 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
/* UAPI on windows uses a bidirectional named pipe
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/Microsoft/go-winio"
 | 
			
		||||
	"golang.org/x/sys/windows"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ipcErrorIO         = -int64(windows.ERROR_BROKEN_PIPE)
 | 
			
		||||
	ipcErrorNotDefined = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
 | 
			
		||||
	ipcErrorProtocol   = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
 | 
			
		||||
	ipcErrorInvalid    = -int64(windows.ERROR_SERVICE_SPECIFIC_ERROR)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s"
 | 
			
		||||
 | 
			
		||||
type UAPIListener struct {
 | 
			
		||||
	listener net.Listener
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (uapi *UAPIListener) Accept() (net.Conn, error) {
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (uapi *UAPIListener) Close() error {
 | 
			
		||||
	return uapi.listener.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (uapi *UAPIListener) Addr() net.Addr {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUAPIListener(name string) (net.Listener, error) {
 | 
			
		||||
	path := fmt.Sprintf(PipeNameFmt, name)
 | 
			
		||||
	return winio.ListenPipe(path, &winio.PipeConfig{
 | 
			
		||||
		InputBufferSize:  2048,
 | 
			
		||||
		OutputBufferSize: 2048,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user