diff --git a/app/src/main/java/com/wireguard/android/util/RootShell.java b/app/src/main/java/com/wireguard/android/util/RootShell.java index 7fbafa33..9e487a31 100644 --- a/app/src/main/java/com/wireguard/android/util/RootShell.java +++ b/app/src/main/java/com/wireguard/android/util/RootShell.java @@ -31,6 +31,7 @@ public class RootShell { private final String deviceNotRootedMessage; private final File localBinaryDir; private final File localTemporaryDir; + private final Object lock = new Object(); private final String preamble; private Process process; private BufferedReader stderr; @@ -57,16 +58,18 @@ public class RootShell { return false; } - public synchronized boolean isRunning() { - try { - // Throws an exception if the process hasn't finished yet. - if (process != null) - process.exitValue(); - } catch (final IllegalThreadStateException ignored) { - // The existing process is still running. - return true; + private boolean isRunning() { + synchronized (lock) { + try { + // Throws an exception if the process hasn't finished yet. + if (process != null) + process.exitValue(); + return false; + } catch (final IllegalThreadStateException ignored) { + // The existing process is still running. + return true; + } } - return false; } /** @@ -77,101 +80,108 @@ public class RootShell { * @param command Command to run as root. * @return The exit value of the command. */ - public synchronized int run(final Collection output, final String command) + public int run(final Collection output, final String command) throws IOException, NoRootException { - start(); - final String marker = UUID.randomUUID().toString(); - final String script = "echo " + marker + "; echo " + marker + " >&2; (" + command + - "); ret=$?; echo " + marker + " $ret; echo " + marker + " $ret >&2\n"; - Log.v(TAG, "executing: " + command); - stdin.write(script); - stdin.flush(); - String line; - int errnoStdout = Integer.MIN_VALUE; - int errnoStderr = Integer.MAX_VALUE; - int markersSeen = 0; - while ((line = stdout.readLine()) != null) { - if (line.startsWith(marker)) { - ++markersSeen; - if (line.length() > marker.length() + 1) { - errnoStdout = Integer.valueOf(line.substring(marker.length() + 1)); - break; + synchronized (lock) { + /* Start inside synchronized block to prevent a concurrent call to stop(). */ + start(); + final String marker = UUID.randomUUID().toString(); + final String script = "echo " + marker + "; echo " + marker + " >&2; (" + command + + "); ret=$?; echo " + marker + " $ret; echo " + marker + " $ret >&2\n"; + Log.v(TAG, "executing: " + command); + stdin.write(script); + stdin.flush(); + String line; + int errnoStdout = Integer.MIN_VALUE; + int errnoStderr = Integer.MAX_VALUE; + int markersSeen = 0; + while ((line = stdout.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStdout = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 0) { + if (output != null) + output.add(line); + Log.v(TAG, "stdout: " + line); } - } else if (markersSeen > 0) { - if (output != null) - output.add(line); - Log.v(TAG, "stdout: " + line); } - } - while ((line = stderr.readLine()) != null) { - if (line.startsWith(marker)) { - ++markersSeen; - if (line.length() > marker.length() + 1) { - errnoStderr = Integer.valueOf(line.substring(marker.length() + 1)); - break; + while ((line = stderr.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStderr = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 2) { + Log.v(TAG, "stderr: " + line); } - } else if (markersSeen > 2) { - Log.v(TAG, "stderr: " + line); } + if (markersSeen != 4) + throw new IOException("Expected 4 markers, received " + markersSeen); + if (errnoStdout != errnoStderr) + throw new IOException("Unable to read exit status"); + Log.v(TAG, "exit: " + errnoStdout); + return errnoStdout; } - if (markersSeen != 4) - throw new IOException("Expected 4 markers, received " + markersSeen); - if (errnoStdout != errnoStderr) - throw new IOException("Unable to read exit status"); - Log.v(TAG, "exit: " + errnoStdout); - return errnoStdout; } - public synchronized void start() throws IOException, NoRootException { - if (isRunning()) - return; - if (!localBinaryDir.isDirectory() && !localBinaryDir.mkdirs()) - throw new FileNotFoundException("Could not create local binary directory"); - if (!localTemporaryDir.isDirectory() && !localTemporaryDir.mkdirs()) - throw new FileNotFoundException("Could not create local temporary directory"); + public void start() throws IOException, NoRootException { if (!isExecutableInPath(SU)) throw new NoRootException(deviceNotRootedMessage); - try { - final ProcessBuilder builder = new ProcessBuilder().command(SU); - builder.environment().put("LC_ALL", "C"); + synchronized (lock) { + if (isRunning()) + return; + if (!localBinaryDir.isDirectory() && !localBinaryDir.mkdirs()) + throw new FileNotFoundException("Could not create local binary directory"); + if (!localTemporaryDir.isDirectory() && !localTemporaryDir.mkdirs()) + throw new FileNotFoundException("Could not create local temporary directory"); try { - process = builder.start(); - } catch (final IOException e) { - // A failure at this stage means the device isn't rooted. - throw new NoRootException(deviceNotRootedMessage, e); - } - stdin = new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8); - stdout = new BufferedReader(new InputStreamReader(process.getInputStream(), - StandardCharsets.UTF_8)); - stderr = new BufferedReader(new InputStreamReader(process.getErrorStream(), - StandardCharsets.UTF_8)); - stdin.write(preamble); - stdin.flush(); - // Check that the shell started successfully. - final String uid = stdout.readLine(); - if (!"0".equals(uid)) { - Log.w(TAG, "Root check did not return correct UID: " + uid); - throw new NoRootException(deviceNotRootedMessage); - } - if (!isRunning()) { - String line; - while ((line = stderr.readLine()) != null) { - Log.w(TAG, "Root check returned an error: " + line); - if (line.contains("Permission denied")) - throw new NoRootException(deviceNotRootedMessage); + final ProcessBuilder builder = new ProcessBuilder().command(SU); + builder.environment().put("LC_ALL", "C"); + try { + process = builder.start(); + } catch (final IOException e) { + // A failure at this stage means the device isn't rooted. + throw new NoRootException(deviceNotRootedMessage, e); } - throw new IOException("Shell failed to start: " + process.exitValue()); + stdin = new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8); + stdout = new BufferedReader(new InputStreamReader(process.getInputStream(), + StandardCharsets.UTF_8)); + stderr = new BufferedReader(new InputStreamReader(process.getErrorStream(), + StandardCharsets.UTF_8)); + stdin.write(preamble); + stdin.flush(); + // Check that the shell started successfully. + final String uid = stdout.readLine(); + if (!"0".equals(uid)) { + Log.w(TAG, "Root check did not return correct UID: " + uid); + throw new NoRootException(deviceNotRootedMessage); + } + if (!isRunning()) { + String line; + while ((line = stderr.readLine()) != null) { + Log.w(TAG, "Root check returned an error: " + line); + if (line.contains("Permission denied")) + throw new NoRootException(deviceNotRootedMessage); + } + throw new IOException("Shell failed to start: " + process.exitValue()); + } + } catch (final IOException | NoRootException e) { + stop(); + throw e; } - } catch (final IOException | NoRootException e) { - stop(); - throw e; } } - public synchronized void stop() throws IOException { - if (process != null) { - process.destroy(); - process = null; + public void stop() { + synchronized (lock) { + if (process != null) { + process.destroy(); + process = null; + } } }