rsa: black formatting, ignore idea files

Signed-off-by: HeshamTB <hishaminv@gmail.com>
This commit is contained in:
HeshamTB 2024-01-19 15:58:15 +03:00
parent 8c7bc52d9a
commit a58f9b9337
Signed by: Hesham
GPG Key ID: 74876157D199B09E
2 changed files with 175 additions and 121 deletions

2
.gitignore vendored
View File

@ -9,3 +9,5 @@ build/
#venv dir #venv dir
venv/ venv/
.idea/

270
rsa.py
View File

@ -2,7 +2,7 @@
# Copyright (C) 2019-2021 Hesham T. Banafa # Copyright (C) 2019-2021 Hesham T. Banafa
#program to generate rsa key pair using methods in EE-305 # program to generate rsa key pair using methods in EE-305
# Hesham Banafa # Hesham Banafa
""" """
@ -17,16 +17,17 @@ import os
import sys import sys
import MillerRabin as mr import MillerRabin as mr
VERSION="1.2.2" VERSION = "1.2.2"
keysFolder = "keys/" keysFolder = "keys/"
byteOrder = "little" byteOrder = "little"
N=0 N = 0
E=1 E = 1
D=2 D = 2
P=3 P = 3
Q=4 Q = 4
PHI=5 PHI = 5
ID=6 ID = 6
def main(): def main():
print("hesham-rsa version ", VERSION) print("hesham-rsa version ", VERSION)
@ -56,7 +57,9 @@ def main():
print(ex) print(ex)
sys.exit(1) sys.exit(1)
sys.exit(0) sys.exit(0)
if sys.argv[1] == "encrypt" and len(sys.argv) == 5: ##rsa encrypt <message> <key> <signer> if (
sys.argv[1] == "encrypt" and len(sys.argv) == 5
): ##rsa encrypt <message> <key> <signer>
msg = sys.argv[2] msg = sys.argv[2]
msg_list = msg.split() msg_list = msg.split()
keyName = sys.argv[3] keyName = sys.argv[3]
@ -67,9 +70,11 @@ def main():
msg_encrypted = "" msg_encrypted = ""
for word in msg_list: for word in msg_list:
msg_encrypted = msg_encrypted + " " + hex(encrypt(word, key_public)) msg_encrypted = msg_encrypted + " " + hex(encrypt(word, key_public))
#msg_encrypted = encrypt(msg, key_public) # msg_encrypted = encrypt(msg, key_public)
print("Encrypted msg: \n", msg_encrypted) print("Encrypted msg: \n", msg_encrypted)
print("Signed: \n", sign(msg_encrypted, signing_key)) ## Adds an encrypted sig at the end of message. print(
"Signed: \n", sign(msg_encrypted, signing_key)
) ## Adds an encrypted sig at the end of message.
sys.exit(0) sys.exit(0)
elif sys.argv[1] == "encrypt": elif sys.argv[1] == "encrypt":
print("Not enough arguments") print("Not enough arguments")
@ -83,7 +88,9 @@ def main():
msg_decrypted = "" msg_decrypted = ""
key = readKeyFile(sys.argv[3]) key = readKeyFile(sys.argv[3])
for cipher_word in cipher_list: for cipher_word in cipher_list:
msg_decrypted = msg_decrypted + " " + str(decrypt(int(cipher_word, 16),key[D],key[N])) msg_decrypted = (
msg_decrypted + " " + str(decrypt(int(cipher_word, 16), key[D], key[N]))
)
if sig == None: if sig == None:
print("\033[91mUnknown signature! \u2717" + "\033[0m") print("\033[91mUnknown signature! \u2717" + "\033[0m")
else: else:
@ -92,19 +99,19 @@ def main():
sys.exit(0) sys.exit(0)
elif sys.argv[1] == "decrypt": elif sys.argv[1] == "decrypt":
print("Not enough arguments") print("Not enough arguments")
print("rsa decrypt \"<cipher>\" <keyid>") print('rsa decrypt "<cipher>" <keyid>')
sys.exit(1) sys.exit(1)
if sys.argv[1] == "list": if sys.argv[1] == "list":
listKeys() listKeys()
sys.exit(0) sys.exit(0)
if sys.argv[1] == "export" and len(sys.argv) == 3: #rsa export <key> if sys.argv[1] == "export" and len(sys.argv) == 3: # rsa export <key>
key_file_name = sys.argv[2] key_file_name = sys.argv[2]
exportKey(key_file_name) exportKey(key_file_name)
sys.exit(0) sys.exit(0)
elif sys.argv[1] == "export": elif sys.argv[1] == "export":
printHelp() printHelp()
sys.exit(1) sys.exit(1)
if sys.argv[1] == "crack" and len(sys.argv) == 3: #rsa crack <key> if sys.argv[1] == "crack" and len(sys.argv) == 3: # rsa crack <key>
keyName = sys.argv[2] keyName = sys.argv[2]
cracked_key = crackKey2(keyName) cracked_key = crackKey2(keyName)
printKey(cracked_key) printKey(cracked_key)
@ -112,22 +119,26 @@ def main():
elif sys.argv[1] == "crack": elif sys.argv[1] == "crack":
printHelp() printHelp()
sys.exit(1) sys.exit(1)
if sys.argv[1] == "is_prime" and len(sys.argv) == 4: #rsa is_prime <base> <N> if sys.argv[1] == "is_prime" and len(sys.argv) == 4: # rsa is_prime <base> <N>
isPrime_cmd(0) isPrime_cmd(0)
sys.exit(0) sys.exit(0)
if sys.argv[1] == "is_prime_mr" and len(sys.argv) == 4: #rsa is_prime_mr <base> <N> if (
sys.argv[1] == "is_prime_mr" and len(sys.argv) == 4
): # rsa is_prime_mr <base> <N>
isPrime_cmd(1) isPrime_cmd(1)
sys.exit(0) sys.exit(0)
if sys.argv[1] == "genrand" and len(sys.argv) == 3: #rsa genrand <bits> if sys.argv[1] == "genrand" and len(sys.argv) == 3: # rsa genrand <bits>
print(gen_random(int(sys.argv[2]))) print(gen_random(int(sys.argv[2])))
sys.exit(0) sys.exit(0)
if sys.argv[1] == "genprime" and len(sys.argv) == 3: #rsa genprime <bits> if sys.argv[1] == "genprime" and len(sys.argv) == 3: # rsa genprime <bits>
print(getPrime(int(sys.argv[2]))) print(getPrime(int(sys.argv[2])))
sys.exit(0) sys.exit(0)
if sys.argv[1] == "prime_factors" and len(sys.argv) == 4: #rsa primefactors <base> <N> if (
sys.argv[1] == "prime_factors" and len(sys.argv) == 4
): # rsa primefactors <base> <N>
prime_factors(sys.argv[3], sys.argv[2]) prime_factors(sys.argv[3], sys.argv[2])
sys.exit(0) sys.exit(0)
if sys.argv[1] == "print" and len(sys.argv) == 3: #rsa print <key> if sys.argv[1] == "print" and len(sys.argv) == 3: # rsa print <key>
printKey(readKeyFile(sys.argv[2])) printKey(readKeyFile(sys.argv[2]))
sys.exit(0) sys.exit(0)
elif sys.argv[1] == "print": elif sys.argv[1] == "print":
@ -137,87 +148,92 @@ def main():
printHelp() printHelp()
sys.exit(0) sys.exit(0)
#No command exit code # No command exit code
printHelp() printHelp()
sys.exit(127) sys.exit(127)
def generateKeys(id, bits=64): def generateKeys(id, bits=64):
from multiprocessing.pool import Pool from multiprocessing.pool import Pool
#Primes of size 32 bit random
#resulting in a 64-bit key mod # Primes of size 32 bit random
# resulting in a 64-bit key mod
pool = Pool() pool = Pool()
result1 = pool.apply_async(getPrime, [int(bits/2)]) result1 = pool.apply_async(getPrime, [int(bits / 2)])
result2 = pool.apply_async(getPrime, [int(bits/2)]) result2 = pool.apply_async(getPrime, [int(bits / 2)])
p = result1.get() p = result1.get()
q = result2.get() q = result2.get()
n = p*q n = p * q
#print("n: ", n) # print("n: ", n)
#lamda(n) = LCM(p-1, q-1) # lamda(n) = LCM(p-1, q-1)
#Since LCM(a,b) = ab/GCD(a,b) # Since LCM(a,b) = ab/GCD(a,b)
#gcd = math.gcd(p-1, q-1) # gcd = math.gcd(p-1, q-1)
#print("GCD: ", gcd) # print("GCD: ", gcd)
#lcm = abs((p-1) * (q-1)) / gcd # lcm = abs((p-1) * (q-1)) / gcd
#print("LCM: ", lcm) # print("LCM: ", lcm)
phi = (p-1)*(q-1) phi = (p - 1) * (q - 1)
#print("phi: ", phi) # print("phi: ", phi)
#e exponant should be 1 < e < lamda(n) and GCD(e, lamda(n)) = 1 (coprime) # e exponant should be 1 < e < lamda(n) and GCD(e, lamda(n)) = 1 (coprime)
# recommended value is 65,537 # recommended value is 65,537
e = 65537 e = 65537
d = pow(e,-1,phi) # d = e^-1 mod phi d = pow(e, -1, phi) # d = e^-1 mod phi
return (n, e, d, p, q, phi, id) return (n, e, d, p, q, phi, id)
def encrypt(message, publicKey): def encrypt(message, publicKey):
msg_text = message msg_text = message
n = publicKey[N] n = publicKey[N]
e = publicKey[E] e = publicKey[E]
#print("using n: {0}, e: {1}".format(n, e)) # print("using n: {0}, e: {1}".format(n, e))
msg_number_form = int.from_bytes(msg_text.encode(), byteOrder) msg_number_form = int.from_bytes(msg_text.encode(), byteOrder)
#print("Word: %s or %d" % (msg_text, msg_number_form)) # print("Word: %s or %d" % (msg_text, msg_number_form))
msg_encrypted_number_form = pow(msg_number_form, e, n) # c = msg^e mod n msg_encrypted_number_form = pow(msg_number_form, e, n) # c = msg^e mod n
return msg_encrypted_number_form return msg_encrypted_number_form
def decrypt(cipher, privateKey, n): def decrypt(cipher, privateKey, n):
msg_encrypted_number_form = cipher msg_encrypted_number_form = cipher
d = privateKey d = privateKey
msg_decrypted_number_form = pow(msg_encrypted_number_form, d, n) # msg = c^d mod n msg_decrypted_number_form = pow(msg_encrypted_number_form, d, n) # msg = c^d mod n
msg_decrypted = int(msg_decrypted_number_form) msg_decrypted = int(msg_decrypted_number_form)
try: try:
msg_decrypted = str(msg_decrypted.to_bytes(msg_decrypted.bit_length(), byteOrder).decode()).strip() msg_decrypted = str(
msg_decrypted.to_bytes(msg_decrypted.bit_length(), byteOrder).decode()
).strip()
except UnicodeDecodeError: except UnicodeDecodeError:
#print("decrypt: Cant decrypt properly") # print("decrypt: Cant decrypt properly")
return "" return ""
return msg_decrypted return msg_decrypted
def getPrime(bits): def getPrime(bits):
while True: while True:
#Byte order "little" or "big" does not matter here since we want a random number from os.urandom() # Byte order "little" or "big" does not matter here since we want a random number from os.urandom()
x = int.from_bytes(os.urandom(int(bits/8)), byteOrder) x = int.from_bytes(os.urandom(int(bits / 8)), byteOrder)
print('Trying: ', x, end="\n") print("Trying: ", x, end="\n")
if mr.is_prime(x): if mr.is_prime(x):
print("\nprime: ", x, '\n') print("\nprime: ", x, "\n")
return x return x
#backTrack(x) # backTrack(x)
def isPrime(number): def isPrime(number):
if number == 2: if number == 2:
return True return True
#if 2 devides number then num is not prime. pg.21 # if 2 devides number then num is not prime. pg.21
if number % 2 == 0 or number == 1: if number % 2 == 0 or number == 1:
return False return False
#largest integer less than or equal square root of number (K) # largest integer less than or equal square root of number (K)
rootOfNum = math.sqrt(number) rootOfNum = math.sqrt(number)
K = math.floor(rootOfNum) K = math.floor(rootOfNum)
#Take odd D such that 1 < D <= K # Take odd D such that 1 < D <= K
#If D devides number then number is not prime. otherwise prime. # If D devides number then number is not prime. otherwise prime.
for D in range(1, K, 2): for D in range(1, K, 2):
if D % 2 == 0 or D == 1: if D % 2 == 0 or D == 1:
pass pass
@ -226,23 +242,26 @@ def isPrime(number):
return False return False
return True return True
def gen_random(bits: int): def gen_random(bits: int):
x = int.from_bytes(os.urandom(int(bits/8)), byteOrder) x = int.from_bytes(os.urandom(int(bits / 8)), byteOrder)
return x return x
def sign(encrypted_msg, key): def sign(encrypted_msg, key):
enc_msg = str(encrypted_msg) enc_msg = str(encrypted_msg)
encrypted_msg_list = enc_msg.split() encrypted_msg_list = enc_msg.split()
enc_sig = encrypt("sig:"+key[ID], (key[N], key[D])) enc_sig = encrypt("sig:" + key[ID], (key[N], key[D]))
encrypted_msg_list.append(hex(enc_sig)) encrypted_msg_list.append(hex(enc_sig))
signed_msg = "" signed_msg = ""
for word in encrypted_msg_list: for word in encrypted_msg_list:
signed_msg = str(signed_msg) + " " + str(word) signed_msg = str(signed_msg) + " " + str(word)
return signed_msg.strip() return signed_msg.strip()
def verify(cipher_list): def verify(cipher_list):
local_keys = os.listdir(keysFolder) local_keys = os.listdir(keysFolder)
cipher_list.reverse() #To get last word using index 0 cipher_list.reverse() # To get last word using index 0
encrypted_sig = cipher_list[0] encrypted_sig = cipher_list[0]
cipher_list.reverse() cipher_list.reverse()
sig = None sig = None
@ -251,12 +270,14 @@ def verify(cipher_list):
print("Found key: ", key_name) print("Found key: ", key_name)
sig = str(decrypt(int(encrypted_sig, 16), key[E], key[N])) sig = str(decrypt(int(encrypted_sig, 16), key[E], key[N]))
if "sig:" in sig: if "sig:" in sig:
return sig.replace("sig:","") return sig.replace("sig:", "")
else: continue else:
else: return None continue
else:
return None
def isPrime_cmd(func): def isPrime_cmd(func):
number = int_base_n_from_str(sys.argv[3], sys.argv[2]) number = int_base_n_from_str(sys.argv[3], sys.argv[2])
if func == 0: if func == 0:
@ -265,13 +286,13 @@ def isPrime_cmd(func):
prime = mr.is_prime(number) prime = mr.is_prime(number)
if prime: if prime:
print('Prime') print("Prime")
#print(number) # print(number)
else: else:
print('Not prime') print("Not prime")
def prime_factors(number, base): def prime_factors(number, base):
num = int_base_n_from_str(number, base) num = int_base_n_from_str(number, base)
factors = {1: 1} factors = {1: 1}
k = 0 k = 0
@ -281,9 +302,9 @@ def prime_factors(number, base):
if k != 0: if k != 0:
factors.update({2: k}) factors.update({2: k})
for i in range(3, int(math.sqrt(num))+1, 2): for i in range(3, int(math.sqrt(num)) + 1, 2):
j = 0 j = 0
while (num % i == 0): while num % i == 0:
j += 1 j += 1
num = num / i num = num / i
if j != 0: if j != 0:
@ -292,49 +313,64 @@ def prime_factors(number, base):
factors.update({int(num): 1}) factors.update({int(num): 1})
print(factors) print(factors)
def readKeyFile(keyName): def readKeyFile(keyName):
key = tuple() key = tuple()
with open(keysFolder+keyName, "r") as keyFile: with open(keysFolder + keyName, "r") as keyFile:
tempkey = keyFile.readlines() tempkey = keyFile.readlines()
if len(tempkey) == 3: #means it only public part (n, e, id) if len(tempkey) == 3: # means it only public part (n, e, id)
key = (int(tempkey[N].strip(), 16), int(tempkey[E].strip(), 16), 0, 0, 0, 0, tempkey[2]) key = (
else: #Make this a loop from 0 to 5 int(tempkey[N].strip(), 16),
key = (int(tempkey[N].strip(), 16), int(tempkey[E].strip(), 16),
0,
0,
0,
0,
tempkey[2],
)
else: # Make this a loop from 0 to 5
key = (
int(tempkey[N].strip(), 16),
int(tempkey[E].strip(), 16), int(tempkey[E].strip(), 16),
int(tempkey[D].strip(), 16), int(tempkey[D].strip(), 16),
int(tempkey[P].strip(), 16), int(tempkey[P].strip(), 16),
int(tempkey[Q].strip(), 16), int(tempkey[Q].strip(), 16),
int(tempkey[PHI].strip(), 16), int(tempkey[PHI].strip(), 16),
str(tempkey[ID].strip())) str(tempkey[ID].strip()),
)
return key return key
def saveKeyFile(key, fileName): def saveKeyFile(key, fileName):
if not os.path.isdir(keysFolder): if not os.path.isdir(keysFolder):
os.makedirs(keysFolder) os.makedirs(keysFolder)
with open(keysFolder+fileName, "w") as keyFile: with open(keysFolder + fileName, "w") as keyFile:
for entry in range(0, 6): for entry in range(0, 6):
if key[entry] != 0: if key[entry] != 0:
keyFile.write(hex(key[entry])+"\n") keyFile.write(hex(key[entry]) + "\n")
else: else:
pass pass
keyFile.write(key[ID]+"\n") keyFile.write(key[ID] + "\n")
def printKey(key): def printKey(key):
n = key[N] n = key[N]
e = key[E] e = key[E]
d = key[D] d = key[D]
id = key[ID] id = key[ID]
print("----------------------------------------------"+ print(
"\nID: {}".format(id) + "----------------------------------------------"
"\n{}-BIT KEY".format(n.bit_length())+ + "\nID: {}".format(id)
"\nPUBLIC PART:"+ + "\n{}-BIT KEY".format(n.bit_length())
"\n{0}/{1}".format(hex(n), hex(e))+ + "\nPUBLIC PART:"
"\nPTIVATE PART:"+ + "\n{0}/{1}".format(hex(n), hex(e))
"\n{0}".format(hex(d))+ + "\nPTIVATE PART:"
"\n----------------------------------------------", + "\n{0}".format(hex(d))
+ "\n----------------------------------------------",
) )
def listKeys(): def listKeys():
if not os.path.isdir(keysFolder): if not os.path.isdir(keysFolder):
os.makedirs(keysFolder) os.makedirs(keysFolder)
@ -348,15 +384,18 @@ def listKeys():
key = readKeyFile(keyName) key = readKeyFile(keyName)
if key[D] == 0: if key[D] == 0:
check = "".strip() check = "".strip()
else: check = '\u2713' else:
check = "\u2713"
print("%10s%7s%7s-bit" % (key[ID].strip(), check, key[N].bit_length())) print("%10s%7s%7s-bit" % (key[ID].strip(), check, key[N].bit_length()))
def exportKey(keyFileName): def exportKey(keyFileName):
key = readKeyFile(keyFileName) key = readKeyFile(keyFileName)
public_key = (key[N], key[E], 0, 0, 0, 0, key[ID]) public_key = (key[N], key[E], 0, 0, 0, 0, key[ID])
saveKeyFile(public_key, key[ID]+"-public") saveKeyFile(public_key, key[ID] + "-public")
print("Saved public form of key {} in keys folder".format(key[ID])) print("Saved public form of key {} in keys folder".format(key[ID]))
def crackKey(keyName): def crackKey(keyName):
print("in crack") print("in crack")
key = readKeyFile(keyName) key = readKeyFile(keyName)
@ -367,77 +406,90 @@ def crackKey(keyName):
# if number devides n then it p or q # if number devides n then it p or q
if n % number == 0: if n % number == 0:
p = number p = number
q = int(n/p) q = int(n / p)
phi = (p-1)*(q-1) phi = (p - 1) * (q - 1)
e = 65537 e = 65537
d = pow(e,-1,phi) d = pow(e, -1, phi)
key_cracked = (n, e, d, p, q, phi, str(keyName+"-cracked")) key_cracked = (n, e, d, p, q, phi, str(keyName + "-cracked"))
return key_cracked return key_cracked
else: pass else:
else: pass pass
else:
pass
def crackKey2(keyName): def crackKey2(keyName):
print("in crack") print("in crack")
key = readKeyFile(keyName) key = readKeyFile(keyName)
n = key[N] n = key[N]
print("n: ", n) print("n: ", n)
bits = int(n.bit_length()/2) bits = int(n.bit_length() / 2)
print("bits: ", bits) print("bits: ", bits)
while True: while True:
number = int.from_bytes(os.urandom(int(bits/8)), byteOrder) number = int.from_bytes(os.urandom(int(bits / 8)), byteOrder)
if number == 0 or number == 1: continue if number == 0 or number == 1:
continue
print("Trying prime: ", number, end="\r") print("Trying prime: ", number, end="\r")
# if number devides n then it p or q # if number devides n then it p or q
if n % number == 0: if n % number == 0:
print("\nFound a factor") print("\nFound a factor")
p = number p = number
print("p: ", p) print("p: ", p)
q = int(n/p) q = int(n / p)
phi = (p-1)*(q-1) phi = (p - 1) * (q - 1)
if phi == 0: continue if phi == 0:
continue
e = 65537 e = 65537
d = pow(e,-1,phi) d = pow(e, -1, phi)
key_cracked = (n, e, d, p, q, phi, str(keyName+"-cracked")) key_cracked = (n, e, d, p, q, phi, str(keyName + "-cracked"))
print(key_cracked) print(key_cracked)
return key_cracked return key_cracked
else: continue else:
continue
def printHelp(): def printHelp():
print("commands:") print("commands:")
print("rsa gen <keysize> <keyname>") print("rsa gen <keysize> <keyname>")
print("rsa encrypt <message> <key> <signer>") print("rsa encrypt <message> <key> <signer>")
print("rsa decrypt \"<cipher>\" <key>") print('rsa decrypt "<cipher>" <key>')
print("rsa export <key>") print("rsa export <key>")
print("rsa crack <key>") print("rsa crack <key>")
print("rsa print <key>") print("rsa print <key>")
print("rsa list") print("rsa list")
def backTrack(x): def backTrack(x):
#Back track and clear terminal with length of x # Back track and clear terminal with length of x
length = len(str(x)) length = len(str(x))
while length > 0: while length > 0:
print("\b",end="") print("\b", end="")
length -= 1 length -= 1
def keyExist(keyName): def keyExist(keyName):
exist = os.path.exists(keysFolder+keyName) exist = os.path.exists(keysFolder + keyName)
return exist return exist
def int_base_n_from_str(st: str, base):
def int_base_n_from_str(st: str, base):
try: try:
base = int(base) base = int(base)
except ValueError as e: except ValueError as e:
print(f'Value {sys.argv[2]} is not a valid base (2, 8, 10, 16)', print(
file=sys.stderr) f"Value {sys.argv[2]} is not a valid base (2, 8, 10, 16)", file=sys.stderr
)
exit(-1) exit(-1)
try: try:
number = int(st, base) number = int(st, base)
except ValueError as e: except ValueError as e:
print(f'Value {sys.argv[3]} is not valid for as a base {base} number', print(
file=sys.stderr) f"Value {sys.argv[3]} is not valid for as a base {base} number",
file=sys.stderr,
)
exit(-1) exit(-1)
return number return number
if __name__ == "__main__": if __name__ == "__main__":
main() main()