crypto: Slightly Java-ify the Curve25519 implementation

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Samuel Holland 2018-06-17 14:50:09 -05:00
parent bcae77b989
commit d3a8291a7a

View File

@ -10,23 +10,23 @@ import java.util.Arrays;
/** /**
* Implementation of the Curve25519 elliptic curve algorithm. * Implementation of the Curve25519 elliptic curve algorithm.
* <p> * <p>
* This implementation was imported to WireGuard from noise-java:
* https://github.com/rweather/noise-java
* <p>
* This implementation is based on that from arduinolibs: * This implementation is based on that from arduinolibs:
* https://github.com/rweather/arduinolibs * https://github.com/rweather/arduinolibs
* <p> * <p>
* This implementation is copied verbatim from noise-java:
* https://github.com/rweather/noise-java
* <p>
* Differences in this version are due to using 26-bit limbs for the * Differences in this version are due to using 26-bit limbs for the
* representation instead of the 8/16/32-bit limbs in the original. * representation instead of the 8/16/32-bit limbs in the original.
* <p> * <p>
* References: http://cr.yp.to/ecdh.html, RFC 7748 * References: http://cr.yp.to/ecdh.html, RFC 7748
*/ */
@SuppressWarnings("MagicNumber") @SuppressWarnings({"MagicNumber", "NonConstantFieldWithUpperCaseName", "SuspiciousNameCombination"})
public final class Curve25519 { public final class Curve25519 {
// Numbers modulo 2^255 - 19 are broken up into ten 26-bit words. // Numbers modulo 2^255 - 19 are broken up into ten 26-bit words.
private static final int NUM_LIMBS_255BIT = 10; private static final int NUM_LIMBS_255BIT = 10;
private static final int NUM_LIMBS_510BIT = 20; private static final int NUM_LIMBS_510BIT = 20;
private final int[] A; private final int[] A;
private final int[] AA; private final int[] AA;
private final int[] B; private final int[] B;
@ -152,6 +152,38 @@ public final class Curve25519 {
} }
} }
/**
* Subtracts two numbers modulo 2^255 - 19.
*
* @param result The result.
* @param x The first number to subtract.
* @param y The second number to subtract.
*/
private static void sub(final int[] result, final int[] x, final int[] y) {
int index;
int borrow;
// Subtract y from x to generate the intermediate result.
borrow = 0;
for (index = 0; index < NUM_LIMBS_255BIT; ++index) {
borrow = x[index] - y[index] - ((borrow >> 26) & 0x01);
result[index] = borrow & 0x03FFFFFF;
}
// If we had a borrow, then the result has gone negative and we
// have to add 2^255 - 19 to the result to make it positive again.
// The top bits of "borrow" will be all 1's if there is a borrow
// or it will be all 0's if there was no borrow. Easiest is to
// conditionally subtract 19 and then mask off the high bits.
borrow = result[0] - ((-((borrow >> 26) & 0x01)) & 19);
result[0] = borrow & 0x03FFFFFF;
for (index = 1; index < NUM_LIMBS_255BIT; ++index) {
borrow = result[index] - ((borrow >> 26) & 0x01);
result[index] = borrow & 0x03FFFFFF;
}
result[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF;
}
/** /**
* Adds two numbers modulo 2^255 - 19. * Adds two numbers modulo 2^255 - 19.
* *
@ -160,8 +192,7 @@ public final class Curve25519 {
* @param y The second number to add. * @param y The second number to add.
*/ */
private void add(final int[] result, final int[] x, final int[] y) { private void add(final int[] result, final int[] x, final int[] y) {
int carry; int carry = x[0] + y[0];
carry = x[0] + y[0];
result[0] = carry & 0x03FFFFFF; result[0] = carry & 0x03FFFFFF;
for (int index = 1; index < NUM_LIMBS_255BIT; ++index) { for (int index = 1; index < NUM_LIMBS_255BIT; ++index) {
carry = (carry >> 26) + x[index] + y[index]; carry = (carry >> 26) + x[index] + y[index];
@ -200,12 +231,13 @@ public final class Curve25519 {
*/ */
private void evalCurve(final byte[] s) { private void evalCurve(final byte[] s) {
int sposn = 31; int sposn = 31;
int sbit = 6;
int svalue = s[sposn] | 0x40; int svalue = s[sposn] | 0x40;
int swap = 0; int swap = 0;
// Iterate over all 255 bits of "s" from the highest to the lowest. // Iterate over all 255 bits of "s" from the highest to the lowest.
// We ignore the high bit of the 256-bit representation of "s". // We ignore the high bit of the 256-bit representation of "s".
for (int sbit = 6; ; ) { while (true) {
// Conditional swaps on entry to this bit but only if we // Conditional swaps on entry to this bit but only if we
// didn't swap on the previous bit. // didn't swap on the previous bit.
final int select = (svalue >> sbit) & 0x01; final int select = (svalue >> sbit) & 0x01;
@ -263,14 +295,12 @@ public final class Curve25519 {
* @param y The second number to multiply. * @param y The second number to multiply.
*/ */
private void mul(final int[] result, final int[] x, final int[] y) { private void mul(final int[] result, final int[] x, final int[] y) {
int i;
// Multiply the two numbers to create the intermediate result. // Multiply the two numbers to create the intermediate result.
long v = x[0]; long v = x[0];
for (i = 0; i < NUM_LIMBS_255BIT; ++i) { for (int i = 0; i < NUM_LIMBS_255BIT; ++i) {
t1[i] = v * y[i]; t1[i] = v * y[i];
} }
for (i = 1; i < NUM_LIMBS_255BIT; ++i) { for (int i = 1; i < NUM_LIMBS_255BIT; ++i) {
v = x[i]; v = x[i];
for (int j = 0; j < (NUM_LIMBS_255BIT - 1); ++j) { for (int j = 0; j < (NUM_LIMBS_255BIT - 1); ++j) {
t1[i + j] += v * y[j]; t1[i + j] += v * y[j];
@ -281,7 +311,7 @@ public final class Curve25519 {
// Propagate carries and convert back into 26-bit words. // Propagate carries and convert back into 26-bit words.
v = t1[0]; v = t1[0];
t2[0] = ((int) v) & 0x03FFFFFF; t2[0] = ((int) v) & 0x03FFFFFF;
for (i = 1; i < NUM_LIMBS_510BIT; ++i) { for (int i = 1; i < NUM_LIMBS_510BIT; ++i) {
v = (v >> 26) + t1[i]; v = (v >> 26) + t1[i];
t2[i] = ((int) v) & 0x03FFFFFF; t2[i] = ((int) v) & 0x03FFFFFF;
} }
@ -315,8 +345,6 @@ public final class Curve25519 {
* @param x The argument. * @param x The argument.
*/ */
private void pow250(final int[] result, final int[] x) { private void pow250(final int[] result, final int[] x) {
int j;
// The big-endian hexadecimal expansion of (2^250 - 1) is: // The big-endian hexadecimal expansion of (2^250 - 1) is:
// 03FFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF // 03FFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF
// //
@ -329,11 +357,11 @@ public final class Curve25519 {
// Build a pattern of 250 bits in length of repeated copies of 0000000001. // Build a pattern of 250 bits in length of repeated copies of 0000000001.
square(A, x); square(A, x);
for (j = 0; j < 9; ++j) for (int j = 0; j < 9; ++j)
square(A, A); square(A, A);
mul(result, A, x); mul(result, A, x);
for (int i = 0; i < 23; ++i) { for (int i = 0; i < 23; ++i) {
for (j = 0; j < 10; ++j) for (int j = 0; j < 10; ++j)
square(A, A); square(A, A);
mul(result, result, A); mul(result, result, A);
} }
@ -342,7 +370,7 @@ public final class Curve25519 {
// the result to "fill in" the gaps in the pattern. // the result to "fill in" the gaps in the pattern.
square(A, result); square(A, result);
mul(result, result, A); mul(result, result, A);
for (j = 0; j < 8; ++j) { for (int j = 0; j < 8; ++j) {
square(A, A); square(A, A);
mul(result, result, A); mul(result, result, A);
} }
@ -381,18 +409,14 @@ public final class Curve25519 {
* @param size The number of limbs in the high order half of x. * @param size The number of limbs in the high order half of x.
*/ */
private void reduce(final int[] result, final int[] x, final int size) { private void reduce(final int[] result, final int[] x, final int size) {
int index;
int limb;
int carry;
// Calculate (x mod 2^255) + ((x / 2^255) * 19) which will // Calculate (x mod 2^255) + ((x / 2^255) * 19) which will
// either produce the answer we want or it will produce a // either produce the answer we want or it will produce a
// value of the form "answer + j * (2^255 - 19)". There are // value of the form "answer + j * (2^255 - 19)". There are
// 5 left-over bits in the top-most limb of the bottom half. // 5 left-over bits in the top-most limb of the bottom half.
carry = 0; int carry = 0;
limb = x[NUM_LIMBS_255BIT - 1] >> 21; int limb = x[NUM_LIMBS_255BIT - 1] >> 21;
x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF;
for (index = 0; index < size; ++index) { for (int index = 0; index < size; ++index) {
limb += x[NUM_LIMBS_255BIT + index] << 5; limb += x[NUM_LIMBS_255BIT + index] << 5;
carry += (limb & 0x03FFFFFF) * 19 + x[index]; carry += (limb & 0x03FFFFFF) * 19 + x[index];
x[index] = carry & 0x03FFFFFF; x[index] = carry & 0x03FFFFFF;
@ -402,7 +426,7 @@ public final class Curve25519 {
if (size < NUM_LIMBS_255BIT) { if (size < NUM_LIMBS_255BIT) {
// The high order half of the number is short; e.g. for mulA24(). // The high order half of the number is short; e.g. for mulA24().
// Propagate the carry through the rest of the low order part. // Propagate the carry through the rest of the low order part.
for (index = size; index < NUM_LIMBS_255BIT; ++index) { for (int index = size; index < NUM_LIMBS_255BIT; ++index) {
carry += x[index]; carry += x[index];
x[index] = carry & 0x03FFFFFF; x[index] = carry & 0x03FFFFFF;
carry >>= 26; carry >>= 26;
@ -417,7 +441,7 @@ public final class Curve25519 {
// top 5 bits of the highest limb of the bottom half. // top 5 bits of the highest limb of the bottom half.
carry = (x[NUM_LIMBS_255BIT - 1] >> 21) * 19; carry = (x[NUM_LIMBS_255BIT - 1] >> 21) * 19;
x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF;
for (index = 0; index < NUM_LIMBS_255BIT; ++index) { for (int index = 0; index < NUM_LIMBS_255BIT; ++index) {
carry += x[index]; carry += x[index];
result[index] = carry & 0x03FFFFFF; result[index] = carry & 0x03FFFFFF;
carry >>= 26; carry >>= 26;
@ -436,14 +460,11 @@ public final class Curve25519 {
* @param x The number to reduce, and the result. * @param x The number to reduce, and the result.
*/ */
private void reduceQuick(final int[] x) { private void reduceQuick(final int[] x) {
int index;
int carry;
// Perform a trial subtraction of (2^255 - 19) from "x" which is // Perform a trial subtraction of (2^255 - 19) from "x" which is
// equivalent to adding 19 and subtracting 2^255. We add 19 here; // equivalent to adding 19 and subtracting 2^255. We add 19 here;
// the subtraction of 2^255 occurs in the next step. // the subtraction of 2^255 occurs in the next step.
carry = 19; int carry = 19;
for (index = 0; index < NUM_LIMBS_255BIT; ++index) { for (int index = 0; index < NUM_LIMBS_255BIT; ++index) {
carry += x[index]; carry += x[index];
t2[index] = carry & 0x03FFFFFF; t2[index] = carry & 0x03FFFFFF;
carry >>= 26; carry >>= 26;
@ -457,7 +478,7 @@ public final class Curve25519 {
final int mask = -((t2[NUM_LIMBS_255BIT - 1] >> 21) & 0x01); final int mask = -((t2[NUM_LIMBS_255BIT - 1] >> 21) & 0x01);
final int nmask = ~mask; final int nmask = ~mask;
t2[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; t2[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF;
for (index = 0; index < NUM_LIMBS_255BIT; ++index) for (int index = 0; index < NUM_LIMBS_255BIT; ++index)
x[index] = (x[index] & nmask) | (t2[index] & mask); x[index] = (x[index] & nmask) | (t2[index] & mask);
} }
@ -470,36 +491,4 @@ public final class Curve25519 {
private void square(final int[] result, final int[] x) { private void square(final int[] result, final int[] x) {
mul(result, x, x); mul(result, x, x);
} }
/**
* Subtracts two numbers modulo 2^255 - 19.
*
* @param result The result.
* @param x The first number to subtract.
* @param y The second number to subtract.
*/
private static void sub(final int[] result, final int[] x, final int[] y) {
int index;
int borrow;
// Subtract y from x to generate the intermediate result.
borrow = 0;
for (index = 0; index < NUM_LIMBS_255BIT; ++index) {
borrow = x[index] - y[index] - ((borrow >> 26) & 0x01);
result[index] = borrow & 0x03FFFFFF;
}
// If we had a borrow, then the result has gone negative and we
// have to add 2^255 - 19 to the result to make it positive again.
// The top bits of "borrow" will be all 1's if there is a borrow
// or it will be all 0's if there was no borrow. Easiest is to
// conditionally subtract 19 and then mask off the high bits.
borrow = result[0] - ((-((borrow >> 26) & 0x01)) & 19);
result[0] = borrow & 0x03FFFFFF;
for (index = 1; index < NUM_LIMBS_255BIT; ++index) {
borrow = result[index] - ((borrow >> 26) & 0x01);
result[index] = borrow & 0x03FFFFFF;
}
result[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF;
}
} }