255 lines
11 KiB
Kotlin
255 lines
11 KiB
Kotlin
/*
|
|
* Copyright © 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
package com.wireguard.android.model
|
|
|
|
import android.content.BroadcastReceiver
|
|
import android.content.Context
|
|
import android.content.Intent
|
|
import android.os.Build
|
|
import android.util.Log
|
|
import android.widget.Toast
|
|
import androidx.databinding.BaseObservable
|
|
import androidx.databinding.Bindable
|
|
import com.wireguard.android.Application.Companion.get
|
|
import com.wireguard.android.Application.Companion.getBackend
|
|
import com.wireguard.android.Application.Companion.getTunnelManager
|
|
import com.wireguard.android.BR
|
|
import com.wireguard.android.R
|
|
import com.wireguard.android.backend.Statistics
|
|
import com.wireguard.android.backend.Tunnel
|
|
import com.wireguard.android.configStore.ConfigStore
|
|
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
|
|
import com.wireguard.android.util.ErrorMessages
|
|
import com.wireguard.android.util.UserKnobs
|
|
import com.wireguard.android.util.applicationScope
|
|
import com.wireguard.config.Config
|
|
import kotlinx.coroutines.CompletableDeferred
|
|
import kotlinx.coroutines.Dispatchers
|
|
import kotlinx.coroutines.SupervisorJob
|
|
import kotlinx.coroutines.async
|
|
import kotlinx.coroutines.awaitAll
|
|
import kotlinx.coroutines.flow.first
|
|
import kotlinx.coroutines.launch
|
|
import kotlinx.coroutines.withContext
|
|
|
|
/**
|
|
* Maintains and mediates changes to the set of available WireGuard tunnels,
|
|
*/
|
|
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
|
|
private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
|
|
private val context: Context = get()
|
|
private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
|
|
private var haveLoaded = false
|
|
|
|
private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
|
|
val tunnel = ObservableTunnel(this, name, config, state)
|
|
tunnelMap.add(tunnel)
|
|
return tunnel
|
|
}
|
|
|
|
suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()
|
|
|
|
suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
|
|
if (Tunnel.isNameInvalid(name))
|
|
throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
|
|
if (tunnelMap.containsKey(name))
|
|
throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
|
|
addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
|
|
}
|
|
|
|
suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
|
|
val originalState = tunnel.state
|
|
val wasLastUsed = tunnel == lastUsedTunnel
|
|
// Make sure nothing touches the tunnel.
|
|
if (wasLastUsed)
|
|
lastUsedTunnel = null
|
|
tunnelMap.remove(tunnel)
|
|
try {
|
|
if (originalState == Tunnel.State.UP)
|
|
withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
|
|
try {
|
|
withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
|
|
} catch (e: Throwable) {
|
|
if (originalState == Tunnel.State.UP)
|
|
withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
|
|
throw e
|
|
}
|
|
} catch (e: Throwable) {
|
|
// Failure, put the tunnel back.
|
|
tunnelMap.add(tunnel)
|
|
if (wasLastUsed)
|
|
lastUsedTunnel = tunnel
|
|
throw e
|
|
}
|
|
}
|
|
|
|
@get:Bindable
|
|
var lastUsedTunnel: ObservableTunnel? = null
|
|
private set(value) {
|
|
if (value == field) return
|
|
field = value
|
|
notifyPropertyChanged(BR.lastUsedTunnel)
|
|
applicationScope.launch { UserKnobs.setLastUsedTunnel(value?.name) }
|
|
}
|
|
|
|
suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
|
|
tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
|
|
}
|
|
|
|
fun onCreate() {
|
|
applicationScope.launch {
|
|
try {
|
|
onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
|
|
} catch (e: Throwable) {
|
|
Log.e(TAG, Log.getStackTraceString(e))
|
|
}
|
|
}
|
|
}
|
|
|
|
private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
|
|
for (name in present)
|
|
addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN)
|
|
applicationScope.launch {
|
|
val lastUsedName = UserKnobs.lastUsedTunnel.first()
|
|
if (lastUsedName != null)
|
|
lastUsedTunnel = tunnelMap[lastUsedName]
|
|
haveLoaded = true
|
|
restoreState(true)
|
|
tunnels.complete(tunnelMap)
|
|
}
|
|
}
|
|
|
|
private fun refreshTunnelStates() {
|
|
applicationScope.launch {
|
|
try {
|
|
val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
|
|
for (tunnel in tunnelMap)
|
|
tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
|
|
} catch (e: Throwable) {
|
|
Log.e(TAG, Log.getStackTraceString(e))
|
|
}
|
|
}
|
|
}
|
|
|
|
suspend fun restoreState(force: Boolean) {
|
|
if (!haveLoaded || (!force && !UserKnobs.restoreOnBoot.first()))
|
|
return
|
|
val previouslyRunning = UserKnobs.runningTunnels.first()
|
|
if (previouslyRunning.isEmpty()) return
|
|
withContext(Dispatchers.IO) {
|
|
try {
|
|
tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(Dispatchers.IO + SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }.awaitAll()
|
|
} catch (e: Throwable) {
|
|
Log.e(TAG, Log.getStackTraceString(e))
|
|
}
|
|
}
|
|
}
|
|
|
|
suspend fun saveState() {
|
|
UserKnobs.setRunningTunnels(tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet())
|
|
}
|
|
|
|
suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
|
|
tunnel.onConfigChanged(withContext(Dispatchers.IO) {
|
|
getBackend().setState(tunnel, tunnel.state, config)
|
|
configStore.save(tunnel.name, config)
|
|
})!!
|
|
}
|
|
|
|
suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
|
|
if (Tunnel.isNameInvalid(name))
|
|
throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
|
|
if (tunnelMap.containsKey(name)) {
|
|
throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
|
|
}
|
|
val originalState = tunnel.state
|
|
val wasLastUsed = tunnel == lastUsedTunnel
|
|
// Make sure nothing touches the tunnel.
|
|
if (wasLastUsed)
|
|
lastUsedTunnel = null
|
|
tunnelMap.remove(tunnel)
|
|
var throwable: Throwable? = null
|
|
var newName: String? = null
|
|
try {
|
|
if (originalState == Tunnel.State.UP)
|
|
withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
|
|
withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
|
|
newName = tunnel.onNameChanged(name)
|
|
if (originalState == Tunnel.State.UP)
|
|
withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
|
|
} catch (e: Throwable) {
|
|
throwable = e
|
|
// On failure, we don't know what state the tunnel might be in. Fix that.
|
|
getTunnelState(tunnel)
|
|
}
|
|
// Add the tunnel back to the manager, under whatever name it thinks it has.
|
|
tunnelMap.add(tunnel)
|
|
if (wasLastUsed)
|
|
lastUsedTunnel = tunnel
|
|
if (throwable != null)
|
|
throw throwable
|
|
newName!!
|
|
}
|
|
|
|
suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
|
|
var newState = tunnel.state
|
|
var throwable: Throwable? = null
|
|
try {
|
|
newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
|
|
if (newState == Tunnel.State.UP)
|
|
lastUsedTunnel = tunnel
|
|
} catch (e: Throwable) {
|
|
throwable = e
|
|
}
|
|
tunnel.onStateChanged(newState)
|
|
saveState()
|
|
if (throwable != null)
|
|
throw throwable
|
|
newState
|
|
}
|
|
|
|
class IntentReceiver : BroadcastReceiver() {
|
|
override fun onReceive(context: Context, intent: Intent?) {
|
|
applicationScope.launch {
|
|
val manager = getTunnelManager()
|
|
if (intent == null) return@launch
|
|
val action = intent.action ?: return@launch
|
|
if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
|
|
manager.refreshTunnelStates()
|
|
return@launch
|
|
}
|
|
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !UserKnobs.allowRemoteControlIntents.first())
|
|
return@launch
|
|
val state: Tunnel.State
|
|
state = when (action) {
|
|
"com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
|
|
"com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
|
|
else -> return@launch
|
|
}
|
|
val tunnelName = intent.getStringExtra("tunnel") ?: return@launch
|
|
val tunnels = manager.getTunnels()
|
|
val tunnel = tunnels[tunnelName] ?: return@launch
|
|
try {
|
|
manager.setTunnelState(tunnel, state)
|
|
} catch (e: Throwable) {
|
|
Toast.makeText(context, ErrorMessages[e], Toast.LENGTH_LONG).show()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
|
|
tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
|
|
}
|
|
|
|
suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
|
|
tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
|
|
}
|
|
|
|
companion object {
|
|
private const val TAG = "WireGuard/TunnelManager"
|
|
}
|
|
}
|