feat: Do not quit when keyfile can not be opened, create key

- Refactor AddTestPeer

Signed-off-by: HeshamTB <hishaminv@gmail.com>
This commit is contained in:
HeshamTB 2024-03-18 01:35:37 +03:00
parent 44961e91dc
commit e5e4641264
Signed by: Hesham
GPG Key ID: 74876157D199B09E
2 changed files with 50 additions and 62 deletions

View File

@ -4,10 +4,8 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
@ -210,7 +208,7 @@ func setup(ctx *cli.Context) error {
if err != nil { if err != nil {
return cli.Exit(err, 1) return cli.Exit(err, 1)
} }
slog.Debug(fmt.Sprintf("Private key: %s", privateKey.String())) slog.Debug(fmt.Sprintf("new public key: %s", privateKey.PublicKey().String()))
return nil return nil
} }
@ -223,8 +221,13 @@ func setup(ctx *cli.Context) error {
privKeyFile, err := os.Open(PrivateKeyPath) privKeyFile, err := os.Open(PrivateKeyPath)
defer privKeyFile.Close() defer privKeyFile.Close()
if err != nil { if err != nil {
return cli.Exit(err, 1) slog.Error(err.Error())
slog.Info("Could not open private key file")
err := createPrivKey()
if err != nil {
return err
} }
} else {
privateKeyStr := make([]byte, 45) privateKeyStr := make([]byte, 45)
n, err := privKeyFile.Read(privateKeyStr) n, err := privKeyFile.Read(privateKeyStr)
if err != nil { if err != nil {
@ -242,6 +245,7 @@ func setup(ctx *cli.Context) error {
} }
slog.Debug("Private key parsed and is correct") slog.Debug("Private key parsed and is correct")
} }
}
wg, err := hvpnnode3.InitWGLink( wg, err := hvpnnode3.InitWGLink(
InterfaceName, InterfaceName,
@ -275,7 +279,7 @@ func setup(ctx *cli.Context) error {
ipPool, err := hvpnnode3.NewPool(VPNIPCIDR) ipPool, err := hvpnnode3.NewPool(VPNIPCIDR)
if err != nil { if err != nil {
slog.Error(fmt.Sprintf("main.IPPool: %s", err)) slog.Error(fmt.Sprintf("IPPool: %s", err))
os.Exit(1) os.Exit(1)
} }
slog.Debug(fmt.Sprintf("Init ip pool %s", VPNIPCIDR)) slog.Debug(fmt.Sprintf("Init ip pool %s", VPNIPCIDR))
@ -286,13 +290,13 @@ func setup(ctx *cli.Context) error {
os.Exit(1) os.Exit(1)
} }
slog.Debug(fmt.Sprintf("main.testVip: IP Pool Test IP: %s", testVip.String())) slog.Debug(fmt.Sprintf("IP Pool Test IP: %s", testVip.String()))
err = ipPool.Free(testVip) err = ipPool.Free(testVip)
if err != nil { if err != nil {
slog.Error("main.testVip: Could not free test Vip from IPPool!", err) slog.Error("Could not free test Vip from IPPool!", err)
os.Exit(1) return cli.Exit(err.Error(), 1)
} }
slog.Debug("main.testVip: Test IP Freed") slog.Debug("Test IP Freed")
IPPool = ipPool IPPool = ipPool
wgLink.IPPool = ipPool wgLink.IPPool = ipPool
@ -328,47 +332,22 @@ func testWgPeerAdd(wgLink *hvpnnode3.WGLink) error {
return err return err
} }
publicKey := privateKey.PublicKey() publicKey := privateKey.PublicKey()
_, err = wgLink.AddPeer(publicKey.String())
if err != nil {
slog.Error(err.Error())
return err
}
slog.Debug(fmt.Sprintf("Added test peer %v", publicKey.String()))
urlsafe := url.QueryEscape(publicKey.String()) err = wgLink.DeletePeer(publicKey.String())
slog.Debug(urlsafe)
ip, err := wgLink.Allocate()
if err != nil { if err != nil {
return err return err
} }
peers, err := wgLink.GetAllPeers()
peerConfig := wgtypes.PeerConfig{ if len(peers) != 0 {
PublicKey: publicKey, slog.Warn(fmt.Sprintf("Expected 0 peers, got %d", len(peers)))
AllowedIPs: []net.IPNet{
{
IP: ip,
Mask: net.IPv4Mask(255, 255, 255, 255),
},
},
}
wgConfig := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peerConfig},
}
err = wgLink.ConfigureDevice(wgLink.Name, wgConfig)
if err != nil {
return err
}
slog.Debug(fmt.Sprintf("Added test peer %v", peerConfig.PublicKey))
wgConfig.ReplacePeers = true
wgConfig.Peers = []wgtypes.PeerConfig{}
err = wgLink.ConfigureDevice(wgLink.Name, wgConfig)
if err != nil {
return err
} }
slog.Debug("Removed test peer") slog.Debug("Removed test peer")
wgLink.Free(ip)
slog.Debug("Freed test peer ip")
return nil return nil
} }

13
link.go
View File

@ -185,12 +185,13 @@ func (wg *WGLink) GetPeer(publickey string) (*wgtypes.Peer, error) {
} }
func (wg *WGLink) getPeer(pubkey wgtypes.Key) (*wgtypes.Peer, error) { func (wg *WGLink) getPeer(pubkey wgtypes.Key) (*wgtypes.Peer, error) {
dev, err := wg.Device(wg.Name)
peers, err := wg.GetAllPeers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, peer := range dev.Peers { for _, peer := range peers {
if peer.PublicKey == pubkey { if peer.PublicKey == pubkey {
return &peer, nil return &peer, nil
} }
@ -198,6 +199,14 @@ func (wg *WGLink) getPeer(pubkey wgtypes.Key) (*wgtypes.Peer, error) {
return nil, proto.PeerDoesNotExist return nil, proto.PeerDoesNotExist
} }
func (wg *WGLink) GetAllPeers() ([]wgtypes.Peer, error) {
dev, err := wg.Device(wg.Name)
if err != nil {
return nil, err
}
return dev.Peers, nil
}
func createARemovePeerCfg(publickey wgtypes.Key) wgtypes.Config { func createARemovePeerCfg(publickey wgtypes.Key) wgtypes.Config {
rmPeerCfg := wgtypes.PeerConfig{ rmPeerCfg := wgtypes.PeerConfig{
Remove: true, Remove: true,