ui: update cleanups

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld 2023-05-03 14:28:40 +02:00
parent 99336f6552
commit 8800f861ed

View File

@ -18,7 +18,9 @@ import androidx.core.content.IntentCompat
import com.wireguard.android.Application import com.wireguard.android.Application
import com.wireguard.android.BuildConfig import com.wireguard.android.BuildConfig
import com.wireguard.android.util.UserKnobs import com.wireguard.android.util.UserKnobs
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow import kotlinx.coroutines.flow.asStateFlow
@ -41,12 +43,16 @@ import kotlin.time.Duration.Companion.seconds
object Updater { object Updater {
private const val TAG = "WireGuard/Updater" private const val TAG = "WireGuard/Updater"
private const val LATEST_VERSION_URL = "https://download.wireguard.com/android-client/latest.sig" private const val LATEST_VERSION_URL =
"https://download.wireguard.com/android-client/latest.sig"
private const val APK_PATH_URL = "https://download.wireguard.com/android-client/%s" private const val APK_PATH_URL = "https://download.wireguard.com/android-client/%s"
private const val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID + "-" private val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID.removeSuffix(".debug") + "-"
private const val APK_NAME_SUFFIX = ".apk" private const val APK_NAME_SUFFIX = ".apk"
private const val RELEASE_PUBLIC_KEY_BASE64 = "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp" private const val RELEASE_PUBLIC_KEY_BASE64 =
private val CURRENT_VERSION = BuildConfig.VERSION_NAME.removeSuffix("-debug") "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp"
private val CURRENT_VERSION = Version(BuildConfig.VERSION_NAME.removeSuffix("-debug"))
private val updaterScope = CoroutineScope(Job() + Dispatchers.IO)
sealed class Progress { sealed class Progress {
object Complete : Progress() object Complete : Progress()
@ -89,7 +95,7 @@ object Updater {
class Failure(val error: Throwable) : Progress() { class Failure(val error: Throwable) : Progress() {
fun retry() { fun retry() {
Application.getCoroutineScope().launch { updaterScope.launch {
downloadAndUpdateWrapErrors() downloadAndUpdateWrapErrors()
} }
} }
@ -104,29 +110,61 @@ object Updater {
mutableState.emit(progress) mutableState.emit(progress)
} }
private fun versionIsNewer(lhs: String, rhs: String): Boolean { private class Sha256Digest(hex: String) {
val lhsParts = lhs.split(".") val bytes: ByteArray
val rhsParts = rhs.split(".")
if (lhsParts.isEmpty() || rhsParts.isEmpty())
throw InvalidParameterException("Version is empty")
for (i in 0 until max(lhsParts.size, rhsParts.size)) { init {
val lhsPart = if (i < lhsParts.size) lhsParts[i].toULong() else 0UL if (hex.length != 64)
val rhsPart = if (i < rhsParts.size) rhsParts[i].toULong() else 0UL throw InvalidParameterException("SHA256 hashes must be 32 bytes long")
if (lhsPart == rhsPart) bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray()
continue
return lhsPart > rhsPart
} }
return false
} }
private fun versionOfFile(name: String): String? { @OptIn(ExperimentalUnsignedTypes::class)
private class Version(version: String) : Comparable<Version> {
val parts: ULongArray
init {
val strParts = version.split(".")
if (strParts.isEmpty())
throw InvalidParameterException("Version has no parts")
parts = ULongArray(strParts.size)
for (i in parts.indices) {
parts[i] = strParts[i].toULong()
}
}
override fun toString(): String {
return parts.joinToString(".")
}
override fun compareTo(other: Version): Int {
for (i in 0 until max(parts.size, other.parts.size)) {
val lhsPart = if (i < parts.size) parts[i] else 0UL
val rhsPart = if (i < other.parts.size) other.parts[i] else 0UL
if (lhsPart > rhsPart)
return 1
else if (lhsPart < rhsPart)
return -1
}
return 0
}
}
private class Update(val fileName: String, val version: Version, val hash: Sha256Digest)
private fun versionOfFile(name: String): Version? {
if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX)) if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX))
return null return null
return name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length) return try {
Version(name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length))
} catch (_: Throwable) {
null
}
} }
private fun verifySignedFileList(signifyDigest: String): Map<String, Sha256Digest> { private fun verifySignedFileList(signifyDigest: String): List<Update> {
val updates = ArrayList<Update>(1)
val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT) val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT)
if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte()) if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte())
throw InvalidKeyException("Invalid public key") throw InvalidKeyException("Invalid public key")
@ -149,32 +187,23 @@ object Updater {
) )
) )
throw SecurityException("Invalid signature") throw SecurityException("Invalid signature")
val hashes: MutableMap<String, Sha256Digest> = HashMap()
for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) { for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) {
val components = line.split(" ", limit = 2) val components = line.split(" ", limit = 2)
if (components.size != 2) if (components.size != 2)
throw InvalidParameterException("Invalid file list format: too few components") throw InvalidParameterException("Invalid file list format: too few components")
hashes[components[1]] = Sha256Digest(components[0]) /* If version is null, it's not a file we understand, but still a legitimate entry, so don't throw. */
val version = versionOfFile(components[1]) ?: continue
updates.add(Update(components[1], version, Sha256Digest(components[0])))
} }
return hashes return updates
} }
private class Sha256Digest(hex: String) { private fun checkForUpdates(): Update? {
val bytes: ByteArray
init {
if (hex.length != 64)
throw InvalidParameterException("SHA256 hashes must be 32 bytes long")
bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray()
}
}
private fun checkForUpdates(): Pair<String, Sha256Digest> {
val connection = URL(LATEST_VERSION_URL).openConnection() as HttpURLConnection val connection = URL(LATEST_VERSION_URL).openConnection() as HttpURLConnection
connection.setRequestProperty("User-Agent", Application.USER_AGENT) connection.setRequestProperty("User-Agent", Application.USER_AGENT)
connection.connect() connection.connect()
if (connection.responseCode != HttpURLConnection.HTTP_OK) if (connection.responseCode != HttpURLConnection.HTTP_OK)
throw IOException("File list could not be fetched: ${connection.responseCode}") throw IOException(connection.responseMessage)
var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */) var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */)
connection.inputStream.use { connection.inputStream.use {
val len = it.read(fileListBytes) val len = it.read(fileListBytes)
@ -182,26 +211,7 @@ object Updater {
throw IOException("File list is empty") throw IOException("File list is empty")
fileListBytes = fileListBytes.sliceArray(0 until len) fileListBytes = fileListBytes.sliceArray(0 until len)
} }
val fileList = verifySignedFileList(fileListBytes.decodeToString()) return verifySignedFileList(fileListBytes.decodeToString()).maxByOrNull { it.version }
if (fileList.isEmpty())
throw InvalidParameterException("File list is empty")
var newestFile: String? = null
var newestVersion: String? = null
var newestFileHash: Sha256Digest? = null
for (file in fileList) {
val fileVersion = versionOfFile(file.key)
try {
if (fileVersion != null && (newestVersion == null || versionIsNewer(fileVersion, newestVersion))) {
newestVersion = fileVersion
newestFile = file.key
newestFileHash = file.value
}
} catch (_: Throwable) {
}
}
if (newestFile == null || newestFileHash == null)
throw InvalidParameterException("File list is empty")
return Pair(newestFile, newestFileHash)
} }
private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) { private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) {
@ -224,14 +234,14 @@ object Updater {
emitProgress(Progress.Rechecking) emitProgress(Progress.Rechecking)
val update = checkForUpdates() val update = checkForUpdates()
val updateVersion = versionOfFile(checkForUpdates().first) ?: throw Exception("No versions returned") if (update == null || update.version <= CURRENT_VERSION) {
if (!versionIsNewer(updateVersion, CURRENT_VERSION)) {
emitProgress(Progress.Complete) emitProgress(Progress.Complete)
return@withContext return@withContext
} }
emitProgress(Progress.Downloading(0UL, 0UL), true) emitProgress(Progress.Downloading(0UL, 0UL), true)
val connection = URL(APK_PATH_URL.format(update.first)).openConnection() as HttpURLConnection val connection =
URL(APK_PATH_URL.format(update.fileName)).openConnection() as HttpURLConnection
connection.setRequestProperty("User-Agent", Application.USER_AGENT) connection.setRequestProperty("User-Agent", Application.USER_AGENT)
connection.connect() connection.connect()
if (connection.responseCode != HttpURLConnection.HTTP_OK) if (connection.responseCode != HttpURLConnection.HTTP_OK)
@ -246,7 +256,8 @@ object Updater {
emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true)
val installer = context.packageManager.packageInstaller val installer = context.packageManager.packageInstaller
val params = PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL) val params =
PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S)
params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED) params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED)
params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */ params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */
@ -275,7 +286,7 @@ object Updater {
} }
emitProgress(Progress.Installing) emitProgress(Progress.Installing)
if (!digest.digest().contentEquals(update.second.bytes)) if (!digest.digest().contentEquals(update.hash.bytes))
throw SecurityException("Update has invalid hash") throw SecurityException("Update has invalid hash")
sessionFailure = false sessionFailure = false
} finally { } finally {
@ -305,10 +316,17 @@ object Updater {
return return
when (val status = when (val status =
intent.getIntExtra(PackageInstaller.EXTRA_STATUS, PackageInstaller.STATUS_FAILURE_INVALID)) { intent.getIntExtra(
PackageInstaller.EXTRA_STATUS,
PackageInstaller.STATUS_FAILURE_INVALID
)) {
PackageInstaller.STATUS_PENDING_USER_ACTION -> { PackageInstaller.STATUS_PENDING_USER_ACTION -> {
val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0)
val userIntervention = IntentCompat.getParcelableExtra(intent, Intent.EXTRA_INTENT, Intent::class.java)!! val userIntervention = IntentCompat.getParcelableExtra(
intent,
Intent.EXTRA_INTENT,
Intent::class.java
)!!
Application.getCoroutineScope().launch { Application.getCoroutineScope().launch {
emitProgress(Progress.NeedsUserIntervention(userIntervention, id)) emitProgress(Progress.NeedsUserIntervention(userIntervention, id))
} }
@ -328,7 +346,8 @@ object Updater {
} catch (_: SecurityException) { } catch (_: SecurityException) {
} }
val message = val message =
intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) ?: "Installation error $status" intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE)
?: "Installation error $status"
Application.getCoroutineScope().launch { Application.getCoroutineScope().launch {
val e = Exception(message) val e = Exception(message)
Log.e(TAG, "Update failure", e) Log.e(TAG, "Update failure", e)
@ -344,21 +363,22 @@ object Updater {
if (installerIsGooglePlay()) if (installerIsGooglePlay())
return return
Application.getCoroutineScope().launch(Dispatchers.IO) { updaterScope.launch {
if (UserKnobs.updaterNewerVersionSeen.firstOrNull()?.let { versionIsNewer(it, CURRENT_VERSION) } == true) if (UserKnobs.updaterNewerVersionSeen.firstOrNull()
?.let { Version(it) > CURRENT_VERSION } == true
)
return@launch return@launch
var waitTime = 15 var waitTime = 15
while (true) { while (true) {
try { try {
val updateVersion = versionOfFile(checkForUpdates().first) ?: throw IllegalStateException("No versions returned") val update = checkForUpdates() ?: continue
if (versionIsNewer(updateVersion, CURRENT_VERSION)) { if (update.version > CURRENT_VERSION) {
Log.i(TAG, "Update available: $updateVersion") Log.i(TAG, "Update available: ${update.version}")
UserKnobs.setUpdaterNewerVersionSeen(updateVersion) UserKnobs.setUpdaterNewerVersionSeen(update.version.toString())
return@launch return@launch
} }
} catch (e: Throwable) { } catch (_: Throwable) {
Log.e(TAG, "Failed to check for updates", e)
} }
delay(waitTime.minutes) delay(waitTime.minutes)
waitTime = 45 waitTime = 45
@ -366,18 +386,17 @@ object Updater {
} }
UserKnobs.updaterNewerVersionSeen.onEach { ver -> UserKnobs.updaterNewerVersionSeen.onEach { ver ->
if (ver != null && versionIsNewer( if (ver != null && Version(ver) > CURRENT_VERSION && UserKnobs.updaterNewerVersionConsented.firstOrNull()
ver, ?.let { Version(it) > CURRENT_VERSION } != true
CURRENT_VERSION
) && UserKnobs.updaterNewerVersionConsented.firstOrNull()
?.let { versionIsNewer(it, CURRENT_VERSION) } != true
) )
emitProgress(Progress.Available(ver)) emitProgress(Progress.Available(ver))
}.launchIn(Application.getCoroutineScope()) }.launchIn(Application.getCoroutineScope())
UserKnobs.updaterNewerVersionConsented.onEach { ver -> UserKnobs.updaterNewerVersionConsented.onEach { ver ->
if (ver != null && versionIsNewer(ver, CURRENT_VERSION)) if (ver != null && Version(ver) > CURRENT_VERSION)
downloadAndUpdateWrapErrors() updaterScope.launch {
downloadAndUpdateWrapErrors()
}
}.launchIn(Application.getCoroutineScope()) }.launchIn(Application.getCoroutineScope())
} }