wireguard-android/ui/src/main/java/com/wireguard/android/model/TunnelManager.kt
Harsh Shandilya 2f088938c6 ui: replace GlobalScope with a hand-rolled CoroutineScope
GlobalScope has numerous problems[1] that make it unfit for
use in most applications and making it behave correctly requires
an excessive amount of verbosity that's alleviated simply by using
any other scope. Since we run multiple operations in the context
of the application's lifecycle, introduce a new scope that is created
when our application is, and cancelled upon its termination.

While at it, make the scope default to Dispatchers.IO to reduce pressure
on the UI event loop. Tasks requiring access to the UI thread appropriately
switch context making the change completely safe.

1: https://medium.com/@elizarov/the-reason-to-avoid-globalscope-835337445abc

Signed-off-by: Harsh Shandilya <me@msfjarvis.dev>
2020-09-16 18:01:06 +02:00

259 lines
11 KiB
Kotlin

/*
* Copyright © 2017-2019 WireGuard LLC. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package com.wireguard.android.model
import android.annotation.SuppressLint
import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.os.Build
import android.util.Log
import androidx.databinding.BaseObservable
import androidx.databinding.Bindable
import com.wireguard.android.Application.Companion.get
import com.wireguard.android.Application.Companion.getBackend
import com.wireguard.android.Application.Companion.getSharedPreferences
import com.wireguard.android.Application.Companion.getTunnelManager
import com.wireguard.android.BR
import com.wireguard.android.R
import com.wireguard.android.backend.Statistics
import com.wireguard.android.backend.Tunnel
import com.wireguard.android.configStore.ConfigStore
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
import com.wireguard.android.util.applicationScope
import com.wireguard.config.Config
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
/**
* Maintains and mediates changes to the set of available WireGuard tunnels.
*/
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
// Completed with tunnelMap by onTunnelsLoaded(); getTunnels() suspends on it until then.
private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
// Obtained from Application.get(); used in this class to resolve error-message strings.
private val context: Context = get()
// Backing collection of every known tunnel, keyed by name and constructed with TunnelComparator.
private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)
// Set to true by onTunnelsLoaded(); restoreState() returns early while this is still false.
private var haveLoaded = false
/**
 * Creates an [ObservableTunnel] owned by this manager from the given name, config and state,
 * inserts it into [tunnelMap] and returns it.
 */
private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel =
    ObservableTunnel(this, name, config, state).also { created -> tunnelMap.add(created) }
/**
 * Suspends until the tunnel list has been published by the initial load, then returns it.
 */
suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> {
    return tunnels.await()
}
/**
 * Creates a new tunnel called [name], persists [config] through the config store and
 * registers the tunnel in the DOWN state.
 *
 * Validation and list mutation run on Dispatchers.Main.immediate; the store call runs on
 * Dispatchers.IO.
 *
 * @throws IllegalArgumentException if [name] is invalid or already in use.
 * @throws NullPointerException if [config] is null.
 */
suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
    require(!Tunnel.isNameInvalid(name)) { context.getString(R.string.tunnel_error_invalid_name) }
    require(!tunnelMap.containsKey(name)) { context.getString(R.string.tunnel_error_already_exists, name) }
    val storedConfig = withContext(Dispatchers.IO) { configStore.create(name, config!!) }
    addToList(name, storedConfig, Tunnel.State.DOWN)
}
/**
 * Removes [tunnel] from the manager and deletes its stored configuration.
 *
 * The tunnel is detached from [tunnelMap] (and from [lastUsedTunnel]) first, brought down if
 * it was UP, and then deleted from the config store. If the store deletion fails, a tunnel
 * that was UP is brought back up. On any failure the tunnel is re-inserted into the list,
 * [lastUsedTunnel] is restored, and the error is rethrown.
 */
suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
    val wasUp = tunnel.state == Tunnel.State.UP
    val wasLastUsed = tunnel == lastUsedTunnel
    // Detach the tunnel up front so nothing else touches it while it is being deleted.
    if (wasLastUsed) {
        lastUsedTunnel = null
    }
    tunnelMap.remove(tunnel)
    try {
        if (wasUp) {
            withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
        }
        try {
            withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
        } catch (deleteError: Throwable) {
            // The config is still on disk, so bring a previously running tunnel back up.
            if (wasUp) {
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
            }
            throw deleteError
        }
    } catch (error: Throwable) {
        // Roll back the bookkeeping done before the try block.
        tunnelMap.add(tunnel)
        if (wasLastUsed) {
            lastUsedTunnel = tunnel
        }
        throw error
    }
}
/**
 * The tunnel most recently recorded as used, or null if there is none.
 *
 * Assigning a different value notifies data-binding observers and synchronously writes the
 * tunnel's name to (or removes it from) shared preferences under KEY_LAST_USED_TUNNEL.
 */
@get:Bindable
@SuppressLint("ApplySharedPref")
var lastUsedTunnel: ObservableTunnel? = null
    private set(value) {
        if (value == field) return
        field = value
        notifyPropertyChanged(BR.lastUsedTunnel)
        val editor = getSharedPreferences().edit()
        if (value == null) {
            editor.remove(KEY_LAST_USED_TUNNEL)
        } else {
            editor.putString(KEY_LAST_USED_TUNNEL, value.name)
        }
        editor.commit()
    }
/**
 * Loads the stored configuration for [tunnel] on Dispatchers.IO, hands it to the tunnel via
 * onConfigChanged on Dispatchers.Main.immediate, and returns the value the tunnel reports.
 */
suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
    val loadedConfig = withContext(Dispatchers.IO) { configStore.load(tunnel.name) }
    tunnel.onConfigChanged(loadedConfig)!!
}
/**
 * Starts the initial tunnel load in the application scope.
 *
 * The stored tunnel names and the backend's running tunnel names are both read on
 * Dispatchers.IO. The result is then published on Dispatchers.Main.immediate, because
 * onTunnelsLoaded() mutates the observable tunnel list and fires property-change
 * notifications, which every other mutator in this class also does on the main dispatcher.
 *
 * Failures are logged and not rethrown.
 */
fun onCreate() {
    applicationScope.launch {
        try {
            val present = withContext(Dispatchers.IO) { configStore.enumerate() }
            val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
            // Do not rely on the scope's default dispatcher for observable-state mutation.
            withContext(Dispatchers.Main.immediate) { onTunnelsLoaded(present, running) }
        } catch (e: Throwable) {
            Log.e(TAG, Log.getStackTraceString(e))
        }
    }
}
/**
 * Populates the manager from the initial load.
 *
 * Adds one tunnel per entry in [present] (UP if its name is in [running], otherwise DOWN,
 * with no config), restores [lastUsedTunnel] from shared preferences, marks the manager as
 * loaded, triggers a forced state restore and finally completes [tunnels].
 */
private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
    present.forEach { name ->
        val initialState = if (name in running) Tunnel.State.UP else Tunnel.State.DOWN
        addToList(name, null, initialState)
    }
    getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null)?.let { lastUsedName ->
        lastUsedTunnel = tunnelMap[lastUsedName]
    }
    haveLoaded = true
    restoreState(true)
    tunnels.complete(tunnelMap)
}
/**
 * Re-reads the backend's running tunnel names and updates every known tunnel's state to
 * match (UP if its name is reported as running, DOWN otherwise).
 *
 * The backend query runs on Dispatchers.IO. The state updates run on
 * Dispatchers.Main.immediate, matching getTunnelState() and setTunnelState(), so that
 * observers of the tunnels are not notified from whatever dispatcher the application scope
 * defaults to.
 *
 * Failures are logged and not rethrown.
 */
private fun refreshTunnelStates() {
    applicationScope.launch {
        try {
            val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
            withContext(Dispatchers.Main.immediate) {
                for (tunnel in tunnelMap)
                    tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
            }
        } catch (e: Throwable) {
            Log.e(TAG, Log.getStackTraceString(e))
        }
    }
}
/**
 * Brings back up the tunnels whose names were saved under KEY_RUNNING_TUNNELS.
 *
 * Does nothing until the initial load has finished, when [force] is false and the
 * KEY_RESTORE_ON_BOOT preference is not enabled, or when no running tunnels were saved.
 * Each matching tunnel is brought up in its own async job; failures are logged.
 *
 * @param force when true, the KEY_RESTORE_ON_BOOT preference is not consulted.
 */
fun restoreState(force: Boolean) {
    if (!haveLoaded) return
    if (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)) return
    val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null)
        ?: return
    if (previouslyRunning.isEmpty()) return
    applicationScope.launch {
        withContext(Dispatchers.IO) {
            try {
                val pending = tunnelMap
                    .filter { candidate -> previouslyRunning.contains(candidate.name) }
                    .map { tunnel -> async(SupervisorJob()) { setTunnelState(tunnel, Tunnel.State.UP) } }
                pending.awaitAll()
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }
}
/**
 * Synchronously stores the names of all tunnels currently in the UP state in shared
 * preferences under KEY_RUNNING_TUNNELS.
 */
@SuppressLint("ApplySharedPref")
fun saveState() {
    val runningNames = tunnelMap
        .filter { it.state == Tunnel.State.UP }
        .map { it.name }
        .toSet()
    getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, runningNames).commit()
}
/**
 * Applies [config] to [tunnel].
 *
 * On Dispatchers.IO the backend is first asked to keep the tunnel's current state with the
 * new config, then the config is saved to the store. The saved config is handed to the
 * tunnel via onConfigChanged on Dispatchers.Main.immediate and its result is returned.
 */
suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
    val savedConfig = withContext(Dispatchers.IO) {
        getBackend().setState(tunnel, tunnel.state, config)
        configStore.save(tunnel.name, config)
    }
    tunnel.onConfigChanged(savedConfig)!!
}
/**
 * Renames [tunnel] to [name] and returns the name reported by the tunnel afterwards.
 *
 * The tunnel is temporarily removed from [tunnelMap] (and from [lastUsedTunnel]), brought
 * down if it was UP, renamed in the config store, and brought back up if it was UP before.
 * Whether or not the rename succeeds, the tunnel is re-inserted into the list and
 * [lastUsedTunnel] is restored before this function returns or throws.
 *
 * @throws IllegalArgumentException if [name] is invalid or already in use.
 * @throws Throwable any failure from the backend or the config store, rethrown after the
 * tunnel has been put back.
 */
suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
    if (Tunnel.isNameInvalid(name))
        throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
    if (tunnelMap.containsKey(name)) {
        throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
    }
    val originalState = tunnel.state
    val wasLastUsed = tunnel == lastUsedTunnel
    // Make sure nothing touches the tunnel.
    if (wasLastUsed)
        lastUsedTunnel = null
    tunnelMap.remove(tunnel)
    var throwable: Throwable? = null
    var newName: String? = null
    try {
        if (originalState == Tunnel.State.UP)
            withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
        withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
        newName = tunnel.onNameChanged(name)
        if (originalState == Tunnel.State.UP)
            withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
    } catch (e: Throwable) {
        throwable = e
        // On failure, we don't know what state the tunnel might be in. Fix that.
        // The refresh itself can fail (for example when this coroutine was cancelled); that
        // must not prevent the tunnel from being put back below, so record it and carry on.
        try {
            getTunnelState(tunnel)
        } catch (refreshError: Throwable) {
            // addSuppressed rejects self-suppression, so skip it if the same instance comes back.
            if (refreshError !== e)
                e.addSuppressed(refreshError)
        }
    }
    // Add the tunnel back to the manager, under whatever name it thinks it has.
    tunnelMap.add(tunnel)
    if (wasLastUsed)
        lastUsedTunnel = tunnel
    if (throwable != null)
        throw throwable
    newName!!
}
/**
 * Asks the backend to move [tunnel] into [state] and returns the state that results.
 *
 * The backend call runs on Dispatchers.IO; everything else runs on
 * Dispatchers.Main.immediate. If the tunnel ends up UP it becomes [lastUsedTunnel].
 * Whether or not the backend call fails, the tunnel is notified of the resulting state
 * (its previous state on failure) and the running set is saved; a failure is then rethrown.
 */
suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
    var resultingState = tunnel.state
    var failure: Throwable? = null
    try {
        resultingState = withContext(Dispatchers.IO) {
            getBackend().setState(tunnel, state, tunnel.getConfigAsync())
        }
        if (resultingState == Tunnel.State.UP) {
            lastUsedTunnel = tunnel
        }
    } catch (e: Throwable) {
        failure = e
    }
    tunnel.onStateChanged(resultingState)
    saveState()
    failure?.let { throw it }
    resultingState
}
/**
 * Broadcast receiver for tunnel-related intents.
 *
 * - "com.wireguard.android.action.REFRESH_TUNNEL_STATES" refreshes all tunnel states.
 * - "com.wireguard.android.action.SET_TUNNEL_UP" / "...SET_TUNNEL_DOWN" change the state of
 *   the tunnel named by the "tunnel" string extra. These are only honoured on API level M or
 *   newer and when the "allow_remote_control_intents" preference is enabled.
 *
 * Any other action, a missing extra or an unknown tunnel name is ignored.
 */
class IntentReceiver : BroadcastReceiver() {
    override fun onReceive(context: Context, intent: Intent?) {
        val manager = getTunnelManager()
        if (intent == null) return
        val action = intent.action ?: return
        if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
            manager.refreshTunnelStates()
            return
        }
        if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !getSharedPreferences().getBoolean("allow_remote_control_intents", false))
            return
        val state = when (action) {
            "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
            "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
            else -> return
        }
        val tunnelName = intent.getStringExtra("tunnel") ?: return
        applicationScope.launch {
            // setTunnelState() rethrows backend failures. Log them here, as the other
            // fire-and-forget launches in this file do, instead of letting them escape.
            try {
                val tunnels = manager.getTunnels()
                val tunnel = tunnels[tunnelName] ?: return@launch
                manager.setTunnelState(tunnel, state)
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }
}
/**
 * Queries the backend for the current state of [tunnel] on Dispatchers.IO, passes it to the
 * tunnel via onStateChanged on Dispatchers.Main.immediate, and returns that call's result.
 */
suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
    val backendState = withContext(Dispatchers.IO) { getBackend().getState(tunnel) }
    tunnel.onStateChanged(backendState)
}
/**
 * Fetches the backend's statistics for [tunnel] on Dispatchers.IO, passes them to the tunnel
 * via onStatisticsChanged on Dispatchers.Main.immediate, and returns that call's result.
 */
suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
    val backendStatistics = withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) }
    tunnel.onStatisticsChanged(backendStatistics)!!
}
companion object {
// Tag used for this class's Log.e calls.
private const val TAG = "WireGuard/TunnelManager"
// Preference key holding the name of the last used tunnel (written by the lastUsedTunnel setter).
private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel"
// Boolean preference read by restoreState() when it is not forced.
private const val KEY_RESTORE_ON_BOOT = "restore_on_boot"
// Preference key for the set of tunnel names that saveState() recorded as UP.
private const val KEY_RUNNING_TUNNELS = "enabled_configs"
}
}