Use more Kotlin-esque code where applicable

Signed-off-by: Harsh Shandilya <me@msfjarvis.dev>
This commit is contained in:
Harsh Shandilya 2020-03-19 14:45:07 +05:30
parent fc0660ca8d
commit 0899b49bb3
7 changed files with 70 additions and 90 deletions

View File

@ -51,9 +51,11 @@ class AddTunnelsSheet : BottomSheetDialogFragment() {
view.viewTreeObserver.removeOnGlobalLayoutListener(this) view.viewTreeObserver.removeOnGlobalLayoutListener(this)
val dialog = dialog as BottomSheetDialog? ?: return val dialog = dialog as BottomSheetDialog? ?: return
behavior = dialog.behavior behavior = dialog.behavior
behavior.state = BottomSheetBehavior.STATE_EXPANDED behavior.apply {
behavior.peekHeight = 0 state = BottomSheetBehavior.STATE_EXPANDED
behavior.addBottomSheetCallback(bottomSheetCallback) peekHeight = 0
addBottomSheetCallback(bottomSheetCallback)
}
dialog.findViewById<View>(R.id.create_empty)?.setOnClickListener { dialog.findViewById<View>(R.id.create_empty)?.setOnClickListener {
dismiss() dismiss()
onRequestCreateConfig() onRequestCreateConfig()

View File

@ -19,9 +19,6 @@ import com.wireguard.android.model.ApplicationData
import com.wireguard.android.util.ErrorMessages import com.wireguard.android.util.ErrorMessages
import com.wireguard.android.util.ObservableKeyedArrayList import com.wireguard.android.util.ObservableKeyedArrayList
import com.wireguard.android.util.ObservableKeyedList import com.wireguard.android.util.ObservableKeyedList
import java9.util.Comparators
import java9.util.function.Function
import java.util.Collections
class AppListDialogFragment : DialogFragment() { class AppListDialogFragment : DialogFragment() {
private val appData: ObservableKeyedList<String, ApplicationData> = ObservableKeyedArrayList() private val appData: ObservableKeyedList<String, ApplicationData> = ObservableKeyedArrayList()
@ -39,8 +36,7 @@ class AppListDialogFragment : DialogFragment() {
val packageName = it.activityInfo.packageName val packageName = it.activityInfo.packageName
applicationData.add(ApplicationData(it.loadIcon(pm), it.loadLabel(pm).toString(), packageName, currentlyExcludedApps.contains(packageName))) applicationData.add(ApplicationData(it.loadIcon(pm), it.loadLabel(pm).toString(), packageName, currentlyExcludedApps.contains(packageName)))
} }
applicationData.sortWith(compareBy(String.CASE_INSENSITIVE_ORDER) { it.name })
Collections.sort(applicationData, Comparators.comparing(Function { obj: ApplicationData -> obj.name }, java.lang.String.CASE_INSENSITIVE_ORDER))
applicationData applicationData
}.whenComplete { data, throwable -> }.whenComplete { data, throwable ->
if (data != null) { if (data != null) {
@ -77,7 +73,7 @@ class AppListDialogFragment : DialogFragment() {
dialog.setOnShowListener { dialog.setOnShowListener {
dialog.getButton(DialogInterface.BUTTON_NEUTRAL).setOnClickListener { dialog.getButton(DialogInterface.BUTTON_NEUTRAL).setOnClickListener {
val selectedItems = appData val selectedItems = appData
.filter { obj: ApplicationData -> obj.isExcludedFromTunnel } .filter { it.isExcludedFromTunnel }
val excludeAll = selectedItems.isEmpty() val excludeAll = selectedItems.isEmpty()
appData.forEach { appData.forEach {

View File

@ -69,7 +69,7 @@ abstract class BaseFragment : Fragment(), OnSelectedTunnelChangedListener {
is TunnelDetailFragmentBinding -> binding.tunnel is TunnelDetailFragmentBinding -> binding.tunnel
is TunnelListItemBinding -> binding.item is TunnelListItemBinding -> binding.item
else -> return else -> return
} } ?: return
Application.getBackendAsync().thenAccept { backend: Backend? -> Application.getBackendAsync().thenAccept { backend: Backend? ->
if (backend is GoBackend) { if (backend is GoBackend) {
val intent = GoBackend.VpnService.prepare(view.context) val intent = GoBackend.VpnService.prepare(view.context)
@ -80,7 +80,7 @@ abstract class BaseFragment : Fragment(), OnSelectedTunnelChangedListener {
return@thenAccept return@thenAccept
} }
} }
setTunnelStateWithPermissionsResult(tunnel!!, checked) setTunnelStateWithPermissionsResult(tunnel, checked)
} }
} }

View File

@ -4,7 +4,6 @@
*/ */
package com.wireguard.android.fragment package com.wireguard.android.fragment
import android.app.Activity
import android.app.Dialog import android.app.Dialog
import android.content.Context import android.content.Context
import android.content.DialogInterface import android.content.DialogInterface
@ -15,7 +14,6 @@ import androidx.fragment.app.DialogFragment
import com.wireguard.android.Application import com.wireguard.android.Application
import com.wireguard.android.R import com.wireguard.android.R
import com.wireguard.android.databinding.ConfigNamingDialogFragmentBinding import com.wireguard.android.databinding.ConfigNamingDialogFragmentBinding
import com.wireguard.android.model.ObservableTunnel
import com.wireguard.config.BadConfigException import com.wireguard.config.BadConfigException
import com.wireguard.config.Config import com.wireguard.config.Config
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
@ -28,13 +26,13 @@ class ConfigNamingDialogFragment : DialogFragment() {
private var imm: InputMethodManager? = null private var imm: InputMethodManager? = null
private fun createTunnelAndDismiss() { private fun createTunnelAndDismiss() {
if (binding != null) { binding?.let {
val name = binding!!.tunnelNameText.text.toString() val name = it.tunnelNameText.text.toString()
Application.getTunnelManager().create(name, config).whenComplete { tunnel: ObservableTunnel?, throwable: Throwable -> Application.getTunnelManager().create(name, config).whenComplete { tunnel, throwable ->
if (tunnel != null) { if (tunnel != null) {
dismiss() dismiss()
} else { } else {
binding!!.tunnelNameTextLayout.error = throwable.message it.tunnelNameTextLayout.error = throwable.message
} }
} }
} }
@ -51,15 +49,16 @@ class ConfigNamingDialogFragment : DialogFragment() {
val configBytes = configText!!.toByteArray(StandardCharsets.UTF_8) val configBytes = configText!!.toByteArray(StandardCharsets.UTF_8)
config = try { config = try {
Config.parse(ByteArrayInputStream(configBytes)) Config.parse(ByteArrayInputStream(configBytes))
} catch (e: BadConfigException) { } catch(e: Exception) {
throw IllegalArgumentException("Invalid config passed to " + javaClass.simpleName, e) when(e) {
} catch (e: IOException) { is BadConfigException, is IOException -> throw IllegalArgumentException("Invalid config passed to ${javaClass.simpleName}", e)
throw IllegalArgumentException("Invalid config passed to " + javaClass.simpleName, e) else -> throw e
}
} }
} }
override fun onCreateDialog(savedInstanceState: Bundle?): Dialog { override fun onCreateDialog(savedInstanceState: Bundle?): Dialog {
val activity: Activity = requireActivity() val activity = requireActivity()
imm = activity.getSystemService(Context.INPUT_METHOD_SERVICE) as InputMethodManager imm = activity.getSystemService(Context.INPUT_METHOD_SERVICE) as InputMethodManager
val alertDialogBuilder = AlertDialog.Builder(activity) val alertDialogBuilder = AlertDialog.Builder(activity)
alertDialogBuilder.setTitle(R.string.import_from_qr_code) alertDialogBuilder.setTitle(R.string.import_from_qr_code)

View File

@ -18,7 +18,6 @@ import com.wireguard.android.databinding.TunnelDetailPeerBinding
import com.wireguard.android.model.ObservableTunnel import com.wireguard.android.model.ObservableTunnel
import com.wireguard.android.ui.EdgeToEdge.setUpRoot import com.wireguard.android.ui.EdgeToEdge.setUpRoot
import com.wireguard.android.ui.EdgeToEdge.setUpScrollingContent import com.wireguard.android.ui.EdgeToEdge.setUpScrollingContent
import com.wireguard.config.Config
import java.util.Timer import java.util.Timer
import java.util.TimerTask import java.util.TimerTask
@ -27,7 +26,7 @@ import java.util.TimerTask
*/ */
class TunnelDetailFragment : BaseFragment() { class TunnelDetailFragment : BaseFragment() {
private var binding: TunnelDetailFragmentBinding? = null private var binding: TunnelDetailFragmentBinding? = null
private var lastState: Tunnel.State? = Tunnel.State.TOGGLE private var lastState = Tunnel.State.TOGGLE
private var timer: Timer? = null private var timer: Timer? = null
private fun formatBytes(bytes: Long): String { private fun formatBytes(bytes: Long): String {
@ -78,9 +77,9 @@ class TunnelDetailFragment : BaseFragment() {
} }
override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) {
if (binding == null) return binding ?: return
binding!!.tunnel = newTunnel binding!!.tunnel = newTunnel
if (newTunnel == null) binding!!.config = null else newTunnel.configAsync.thenAccept { config: Config? -> binding!!.config = config } if (newTunnel == null) binding!!.config = null else newTunnel.configAsync.thenAccept { config -> binding!!.config = config }
lastState = Tunnel.State.TOGGLE lastState = Tunnel.State.TOGGLE
updateStats() updateStats()
} }
@ -94,9 +93,7 @@ class TunnelDetailFragment : BaseFragment() {
} }
override fun onViewStateRestored(savedInstanceState: Bundle?) { override fun onViewStateRestored(savedInstanceState: Bundle?) {
if (binding == null) { binding ?: return
return
}
binding!!.fragment = this binding!!.fragment = this
onSelectedTunnelChanged(null, selectedTunnel) onSelectedTunnelChanged(null, selectedTunnel)
super.onViewStateRestored(savedInstanceState) super.onViewStateRestored(savedInstanceState)

View File

@ -35,9 +35,7 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
private var binding: TunnelEditorFragmentBinding? = null private var binding: TunnelEditorFragmentBinding? = null
private var tunnel: ObservableTunnel? = null private var tunnel: ObservableTunnel? = null
private fun onConfigLoaded(config: Config) { private fun onConfigLoaded(config: Config) {
if (binding != null) { binding?.config = ConfigProxy(config)
binding!!.config = ConfigProxy(config)
}
} }
private fun onConfigSaved(savedTunnel: Tunnel, throwable: Throwable?) { private fun onConfigSaved(savedTunnel: Tunnel, throwable: Throwable?) {
@ -51,8 +49,8 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
val error = ErrorMessages.get(throwable) val error = ErrorMessages.get(throwable)
message = getString(R.string.config_save_error, savedTunnel.name, error) message = getString(R.string.config_save_error, savedTunnel.name, error)
Log.e(TAG, message, throwable) Log.e(TAG, message, throwable)
if (binding != null) { binding?.let {
Snackbar.make(binding!!.mainContainer, message, Snackbar.LENGTH_LONG).show() Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show()
} }
} }
} }
@ -112,9 +110,8 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
override fun onOptionsItemSelected(item: MenuItem): Boolean { override fun onOptionsItemSelected(item: MenuItem): Boolean {
if (item.itemId == R.id.menu_action_save) { if (item.itemId == R.id.menu_action_save) {
if (binding == null) return false binding ?: return false
val newConfig: Config val newConfig = try {
newConfig = try {
binding!!.config!!.resolve() binding!!.config!!.resolve()
} catch (e: Exception) { } catch (e: Exception) {
val error = ErrorMessages.get(e) val error = ErrorMessages.get(e)
@ -129,7 +126,7 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
Log.d(TAG, "Attempting to create new tunnel " + binding!!.name) Log.d(TAG, "Attempting to create new tunnel " + binding!!.name)
val manager = Application.getTunnelManager() val manager = Application.getTunnelManager()
manager.create(binding!!.name, newConfig) manager.create(binding!!.name, newConfig)
.whenComplete { newTunnel, throwable -> onTunnelCreated(newTunnel, throwable) } .whenComplete(this::onTunnelCreated)
} }
tunnel!!.name != binding!!.name -> { tunnel!!.name != binding!!.name -> {
Log.d(TAG, "Attempting to rename tunnel to " + binding!!.name) Log.d(TAG, "Attempting to rename tunnel to " + binding!!.name)
@ -169,7 +166,7 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
binding!!.config = ConfigProxy() binding!!.config = ConfigProxy()
if (tunnel != null) { if (tunnel != null) {
binding!!.name = tunnel!!.name binding!!.name = tunnel!!.name
tunnel!!.configAsync.thenAccept { config: Config -> onConfigLoaded(config) } tunnel!!.configAsync.thenAccept(this::onConfigLoaded)
} else { } else {
binding!!.name = "" binding!!.name = ""
} }
@ -187,8 +184,8 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
val error = ErrorMessages.get(throwable) val error = ErrorMessages.get(throwable)
message = getString(R.string.tunnel_create_error, error) message = getString(R.string.tunnel_create_error, error)
Log.e(TAG, message, throwable) Log.e(TAG, message, throwable)
if (binding != null) { binding?.let {
Snackbar.make(binding!!.mainContainer, message, Snackbar.LENGTH_LONG).show() Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show()
} }
} }
} }
@ -206,16 +203,14 @@ class TunnelEditorFragment : BaseFragment(), AppExclusionListener {
val error = ErrorMessages.get(throwable) val error = ErrorMessages.get(throwable)
message = getString(R.string.tunnel_rename_error, error) message = getString(R.string.tunnel_rename_error, error)
Log.e(TAG, message, throwable) Log.e(TAG, message, throwable)
if (binding != null) { binding?.let {
Snackbar.make(binding!!.mainContainer, message, Snackbar.LENGTH_LONG).show() Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG).show()
} }
} }
} }
override fun onViewStateRestored(savedInstanceState: Bundle?) { override fun onViewStateRestored(savedInstanceState: Bundle?) {
if (binding == null) { binding ?: return
return
}
binding!!.fragment = this binding!!.fragment = this
if (savedInstanceState == null) { if (savedInstanceState == null) {
onSelectedTunnelChanged(null, selectedTunnel) onSelectedTunnelChanged(null, selectedTunnel)

View File

@ -23,7 +23,6 @@ import com.google.android.material.snackbar.Snackbar
import com.google.zxing.integration.android.IntentIntegrator import com.google.zxing.integration.android.IntentIntegrator
import com.wireguard.android.Application import com.wireguard.android.Application
import com.wireguard.android.R import com.wireguard.android.R
import com.wireguard.android.activity.TunnelCreatorActivity
import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler
import com.wireguard.android.databinding.TunnelListFragmentBinding import com.wireguard.android.databinding.TunnelListFragmentBinding
import com.wireguard.android.databinding.TunnelListItemBinding import com.wireguard.android.databinding.TunnelListItemBinding
@ -33,7 +32,6 @@ import com.wireguard.android.ui.EdgeToEdge.setUpFAB
import com.wireguard.android.ui.EdgeToEdge.setUpRoot import com.wireguard.android.ui.EdgeToEdge.setUpRoot
import com.wireguard.android.ui.EdgeToEdge.setUpScrollingContent import com.wireguard.android.ui.EdgeToEdge.setUpScrollingContent
import com.wireguard.android.util.ErrorMessages import com.wireguard.android.util.ErrorMessages
import com.wireguard.android.util.ObservableSortedKeyedList
import com.wireguard.android.widget.MultiselectableRelativeLayout import com.wireguard.android.widget.MultiselectableRelativeLayout
import com.wireguard.config.BadConfigException import com.wireguard.config.BadConfigException
import com.wireguard.config.Config import com.wireguard.config.Config
@ -63,10 +61,11 @@ class TunnelListFragment : BaseFragment() {
// Config text is valid, now create the tunnel… // Config text is valid, now create the tunnel…
newInstance(configText).show(parentFragmentManager, null) newInstance(configText).show(parentFragmentManager, null)
} catch (e: BadConfigException) { } catch (e: Exception) {
onTunnelImportFinished(emptyList(), listOf<Throwable>(e)) when(e) {
} catch (e: IOException) { is BadConfigException, is IOException -> onTunnelImportFinished(emptyList(), listOf<Throwable>(e))
onTunnelImportFinished(emptyList(), listOf<Throwable>(e)) else -> throw e
}
} }
} }
@ -86,28 +85,28 @@ class TunnelListFragment : BaseFragment() {
if (cursor.moveToFirst() && !cursor.isNull(0)) { if (cursor.moveToFirst() && !cursor.isNull(0)) {
name = cursor.getString(0) name = cursor.getString(0)
} }
cursor.close()
} }
if (name.isEmpty()) { if (name.isEmpty()) {
name = Uri.decode(uri.lastPathSegment) name = Uri.decode(uri.lastPathSegment)
} }
var idx = name.lastIndexOf('/') var idx = name.lastIndexOf('/')
if (idx >= 0) { if (idx >= 0) {
require(idx < name.length - 1) { "Illegal file name: $name" } require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) }
name = name.substring(idx + 1) name = name.substring(idx + 1)
} }
val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip") val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip")
if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) {
name = name.substring(0, name.length - ".conf".length) name = name.substring(0, name.length - ".conf".length)
} else { } else {
require(isZip) { "File must be .conf or .zip" } require(isZip) { resources.getString(R.string.bad_extension_error) }
} }
if (isZip) { if (isZip) {
ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> ZipInputStream(contentResolver.openInputStream(uri)).use { zip ->
val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8))
var entry: ZipEntry var entry: ZipEntry?
while (zip.nextEntry.also { entry = it } != null) { while (true) {
entry = zip.nextEntry ?: break
name = entry.name name = entry.name
idx = name.lastIndexOf('/') idx = name.lastIndexOf('/')
if (idx >= 0) { if (idx >= 0) {
@ -121,15 +120,13 @@ class TunnelListFragment : BaseFragment() {
} else { } else {
continue continue
} }
val config: Config? = try { try {
Config.parse(reader) Config.parse(reader)
} catch (e: Exception) { } catch (e: Exception) {
throwables.add(e) throwables.add(e)
null null
} }?.let {
futureTunnels.add(Application.getTunnelManager().create(name, it).toCompletableFuture())
if (config != null) {
futureTunnels.add(Application.getTunnelManager().create(name, config).toCompletableFuture())
} }
} }
} }
@ -146,7 +143,7 @@ class TunnelListFragment : BaseFragment() {
if (throwables.size == 1) { if (throwables.size == 1) {
throw throwables[0] throw throwables[0]
} else { } else {
require(throwables.isNotEmpty()) { "No configurations found" } require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) }
} }
} }
CompletableFuture.allOf(*futureTunnels.toTypedArray()) CompletableFuture.allOf(*futureTunnels.toTypedArray())
@ -177,7 +174,7 @@ class TunnelListFragment : BaseFragment() {
override fun onActivityCreated(savedInstanceState: Bundle?) { override fun onActivityCreated(savedInstanceState: Bundle?) {
super.onActivityCreated(savedInstanceState) super.onActivityCreated(savedInstanceState)
if (savedInstanceState != null) { if (savedInstanceState != null) {
val checkedItems: Collection<Int>? = savedInstanceState.getIntegerArrayList("CHECKED_ITEMS") val checkedItems = savedInstanceState.getIntegerArrayList(CHECKED_ITEMS)
if (checkedItems != null) { if (checkedItems != null) {
for (i in checkedItems) actionModeListener.setItemChecked(i, true) for (i in checkedItems) actionModeListener.setItemChecked(i, true)
} }
@ -225,18 +222,14 @@ class TunnelListFragment : BaseFragment() {
super.onDestroyView() super.onDestroyView()
} }
fun onRequestCreateConfig(view: View?) {
startActivity(Intent(activity, TunnelCreatorActivity::class.java))
}
override fun onSaveInstanceState(outState: Bundle) { override fun onSaveInstanceState(outState: Bundle) {
super.onSaveInstanceState(outState) super.onSaveInstanceState(outState)
outState.putIntegerArrayList("CHECKED_ITEMS", actionModeListener.getCheckedItems()) outState.putIntegerArrayList(CHECKED_ITEMS, actionModeListener.getCheckedItems())
} }
override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) {
if (binding == null) return binding ?: return
Application.getTunnelManager().tunnels.thenAccept { tunnels: ObservableSortedKeyedList<String?, ObservableTunnel?> -> Application.getTunnelManager().tunnels.thenAccept { tunnels ->
if (newTunnel != null) viewForTunnel(newTunnel, tunnels).setSingleSelected(true) if (newTunnel != null) viewForTunnel(newTunnel, tunnels).setSingleSelected(true)
if (oldTunnel != null) viewForTunnel(oldTunnel, tunnels).setSingleSelected(false) if (oldTunnel != null) viewForTunnel(oldTunnel, tunnels).setSingleSelected(false)
} }
@ -255,7 +248,7 @@ class TunnelListFragment : BaseFragment() {
} }
private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>) { private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>) {
var message: String? = null var message = ""
for (throwable in throwables) { for (throwable in throwables) {
val error = ErrorMessages.get(throwable) val error = ErrorMessages.get(throwable)
message = getString(R.string.import_error, error) message = getString(R.string.import_error, error)
@ -276,12 +269,10 @@ class TunnelListFragment : BaseFragment() {
override fun onViewStateRestored(savedInstanceState: Bundle?) { override fun onViewStateRestored(savedInstanceState: Bundle?) {
super.onViewStateRestored(savedInstanceState) super.onViewStateRestored(savedInstanceState)
if (binding == null) { binding ?: return
return
}
binding!!.fragment = this binding!!.fragment = this
Application.getTunnelManager().tunnels.thenAccept { tunnels: ObservableSortedKeyedList<String?, ObservableTunnel?>? -> binding!!.tunnels = tunnels } Application.getTunnelManager().tunnels.thenAccept { tunnels -> binding!!.tunnels = tunnels }
binding!!.rowConfigurationHandler = RowConfigurationHandler { binding: TunnelListItemBinding, tunnel: ObservableTunnel, position: Int -> binding!!.rowConfigurationHandler = RowConfigurationHandler { binding: TunnelListItemBinding, tunnel: ObservableTunnel, position ->
binding.fragment = this binding.fragment = this
binding.root.setOnClickListener { binding.root.setOnClickListener {
if (actionMode == null) { if (actionMode == null) {
@ -297,15 +288,15 @@ class TunnelListFragment : BaseFragment() {
if (actionMode != null) if (actionMode != null)
(binding.root as MultiselectableRelativeLayout).setMultiSelected(actionModeListener.checkedItems.contains(position)) (binding.root as MultiselectableRelativeLayout).setMultiSelected(actionModeListener.checkedItems.contains(position))
else else
(binding.root as MultiselectableRelativeLayout).setSingleSelected(selectedTunnel === tunnel) (binding.root as MultiselectableRelativeLayout).setSingleSelected(selectedTunnel == tunnel)
} }
} }
private fun showSnackbar(message: CharSequence?) { private fun showSnackbar(message: CharSequence) {
if (binding != null) { binding?.let {
val snackbar = Snackbar.make(binding!!.mainContainer, message!!, Snackbar.LENGTH_LONG) Snackbar.make(it.mainContainer, message, Snackbar.LENGTH_LONG)
snackbar.anchorView = binding!!.createFab .setAnchorView(it.createFab)
snackbar.show() .show()
} }
} }
@ -316,6 +307,7 @@ class TunnelListFragment : BaseFragment() {
private inner class ActionModeListener : ActionMode.Callback { private inner class ActionModeListener : ActionMode.Callback {
val checkedItems: MutableCollection<Int> = HashSet() val checkedItems: MutableCollection<Int> = HashSet()
private var resources: Resources? = null private var resources: Resources? = null
fun getCheckedItems(): ArrayList<Int> { fun getCheckedItems(): ArrayList<Int> {
return ArrayList(checkedItems) return ArrayList(checkedItems)
} }
@ -323,8 +315,8 @@ class TunnelListFragment : BaseFragment() {
override fun onActionItemClicked(mode: ActionMode, item: MenuItem): Boolean { override fun onActionItemClicked(mode: ActionMode, item: MenuItem): Boolean {
return when (item.itemId) { return when (item.itemId) {
R.id.menu_action_delete -> { R.id.menu_action_delete -> {
val copyCheckedItems: Iterable<Int> = HashSet(checkedItems) val copyCheckedItems = HashSet(checkedItems)
Application.getTunnelManager().tunnels.thenAccept { tunnels: ObservableSortedKeyedList<String, ObservableTunnel> -> Application.getTunnelManager().tunnels.thenAccept { tunnels ->
val tunnelsToDelete = ArrayList<ObservableTunnel>() val tunnelsToDelete = ArrayList<ObservableTunnel>()
for (position in copyCheckedItems) tunnelsToDelete.add(tunnels[position]) for (position in copyCheckedItems) tunnelsToDelete.add(tunnels[position])
val futures = tunnelsToDelete val futures = tunnelsToDelete
@ -332,18 +324,16 @@ class TunnelListFragment : BaseFragment() {
.toTypedArray() .toTypedArray()
CompletableFuture.allOf(*futures as Array<out CompletableFuture<*>>) CompletableFuture.allOf(*futures as Array<out CompletableFuture<*>>)
.thenApply { futures.size } .thenApply { futures.size }
.whenComplete { count: Int, throwable: Throwable? -> onTunnelDeletionFinished(count, throwable) } .whenComplete(this@TunnelListFragment::onTunnelDeletionFinished)
} }
checkedItems.clear() checkedItems.clear()
mode.finish() mode.finish()
true true
} }
R.id.menu_action_select_all -> { R.id.menu_action_select_all -> {
Application.getTunnelManager().tunnels.thenAccept { tunnels: ObservableSortedKeyedList<String?, ObservableTunnel?> -> Application.getTunnelManager().tunnels.thenAccept { tunnels ->
var i = 0 for (i in 0 until tunnels.size) {
while (i < tunnels.size) {
setItemChecked(i, true) setItemChecked(i, true)
++i
} }
} }
true true
@ -410,6 +400,7 @@ class TunnelListFragment : BaseFragment() {
companion object { companion object {
const val REQUEST_IMPORT = 1 const val REQUEST_IMPORT = 1
private const val REQUEST_TARGET_FRAGMENT = 2 private const val REQUEST_TARGET_FRAGMENT = 2
private const val CHECKED_ITEMS = "CHECKED_ITEMS"
private val TAG = "WireGuard/" + TunnelListFragment::class.java.simpleName private val TAG = "WireGuard/" + TunnelListFragment::class.java.simpleName
} }
} }