pwnhub-crypto-theme

好久没看到Pwnhub出题,于是趁着自己也研究了一下密码学相关的东西,看看能不能做出来一题。然而出题人实在是强,让人再一次明白了思路开阔的重要性

Pwnhub Crypto 解题记录

比赛中我总共看了两个题目,这里先记录以下我在解babyOT的时候的思路

题目描述

题目给出了一个python文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env -S python3 -u
import os
import string
from Crypto.PublicKey import RSA
from Crypto.Util.number import bytes_to_long
from random import SystemRandom
def getkey():
if os.path.isfile("key.pem"):
with open("key.pem", "rb") as f:
key = RSA.importKey(f.read())
else:
key = RSA.generate(2048)
with open("key.pem", "wb") as f:
f.write(key.exportKey("PEM"))
return key
def random_str(n):
return "".join([random.choice(string.ascii_letters) for _ in range(n)])
if __name__ == "__main__":
random = SystemRandom()
key = getkey()
print(key.n)
print(key.e)
while True:
msg0 = bytes_to_long(random_str(2048 // 8 - 1).encode())
msg1 = bytes_to_long(random_str(2048 // 8 - 1).encode())
x0 = random.randrange(key.n)
x1 = random.randrange(key.n)
print(x0)
print(x1)
v = int(input())
print((msg0 + pow(v - x0, key.d, key.n)) % key.n)
print((msg1 + pow(v - x1, key.d, key.n)) % key.n)
guess0 = int(input())
guess1 = int(input())
if guess0 == msg0:
print("You are on the half way of success, work harder!")
if guess1 == msg1:
print(open('flag').read())
exit()

题目本身跑在server上,我们需要直接与其进行通信。不过乍一看,好像就是猜随机数msg0/msg1,还得一次性猜对。其次注意到,我们可以交互的地方是这个v值,以及最后的猜测msg0/msg1的值。题目中还使用了RSA算法,不过这里的RSA是用PyCrypto这个库来生成的,所以目测应该没有什么太大的漏洞。一眼看过去能够注意到的点就只有这些了。

从题目上看,可能和这个叫做OT的东西相关,队友了解到是一个叫做Oblivious Transfer(不经意传输)的协议。于是这里首先要了解一下这个协议本身:

Oblivious Transfer 不经意传输

OT是多方安全计算下使用的算法之一。

协议解决的场景

要介绍一个协议,首先要介绍一下这个算法使用的场景:
假设有两个人Alice/Bob,Alice手上有很多的数据,Bob想要知道Alice的数据。但是两个人都非常小心,不想让对方知道自己的信息,具体来说就是:

  • Alice不想让Bob知道他不该知道的信息
  • Bob不想让Alice知道他选择了哪个信息

也就是说Bob想知道一个Alice信息,但是Alice不知道Bob选择了哪个信息,Bob也不能知道Alice的其他信息。(有一点零知识证明的意思)

算法细节

这边用最简单的【1-2 不经意传输】做例子。1-2的意思是【从两个消息中,选取一个信息】
这里摘录一个wiki的表格:

  1. Alice 有两个秘密消息m0, m1
  2. Alice 使用RSA算法,生成公钥(e, N)对公开,私钥d自己留着。公钥(e, N)告知 Bob。这里需要注意的是,每次通信的时候RSA都要重新生成一对公钥私钥
  3. Alice 产生两个随机数x0, x1,并且将这两个随机数传输给 Bob。
  4. Bob 【决定】要获取的数字编号b(0或者1),以及产生一个随机数
  5. Bob 计算一个数字v=(xb+ke)modNv=(x_b+k^e) mod N 这里的e即为前面给出来的RSA的公钥。并且将这个v发送给 Alice
  6. Alice 计算多个kik_i,其中一个kik_i将会等于k

k0=(vx0)dmodNk1=(vx1)dmodNk_0=(v-x_0)^d mod N \\ k_1=(v-x_1)^d mod N

注意由于此时v并不是 Alice 产生的,所以此时的 Alice 并不知道哪一个k是 Bob 需要的
7. Alice 将生成的值与自己手上的信息进行相加,得到全新的信息

m0=m0+k0m1=m1+k1m_0'=m_0+k_0\\ m_1'=m_1+k_1

并且将信息发送给Bob。因为此时每一个信息都增加了kik_i,所以 Bob 无法直接还原信息m
8. Bob 此时知道自己选择的信息编号b,于是选出mbm_b,计算出kbk_b,并且用mb=mbkbm_b=m_b'-k_b得到此时的解密信息。

如何解决场景

对于Alice而言:

  • Alice 能够知道的只是 Bob给出来的一个随机数xbx_b算出来的值v,这个值还被一个用RSA公钥加密的k相加过,所以Alice无法知道这个xbx_b具体是哪个值

对于Bob而言:

  • Alice 交出了所有的信息,但是Alice交出的信息中,除了自己指定的信息,另一个信息(假设是xbx_{b'})被加密成了mb=mb+(xb+kexb)dm_{b'}'=m_{b'}+(x_b+k^e-x_{b'})^d。这个信息 Bob 已经无法还原了,所以Bob只能获得自己想要的信息,其他信息只能抛弃

不过换个角度来说,Bob交给了Alice一个有可能判别身份的数字,Alice交给了Bob自己所有的信息。从结果上来看,两个地方都存在攻击面。实际上OT在实现上可以使用不同的公钥加密方式,不一定非要使用RSA。

题目分析

把协议过了一遍,发现这个题目本质上就是Server就是Alice, 我们来模拟Bob,解开被Alice加密的那些信息。不过乍一看,生成RSA密钥对用的是PyCrypto,那这个生成算法估计是没啥问题的。而且题目代码非常简洁,感觉不到什么可以被利用的地方。于是只好和队友重新过了一遍题目,发现几个问题点:

  • 我们给出来的v可以和协议不一样,直接给出x0x_0,这样我们可以直接得到msg0,不过这个时候msg1的为msg1+(x0x1)dmodnmsg1 + (x_0- x_1)^d mod n
  • RSA的密钥对仅仅在代码最开头生成了一次。
  • msg0/msg1是ascii码,意味着都是可见的字符。进一步来说,如果能够进行一定的限制的话,msg0/msg1存在爆破的可能。

故事一开始,我猜测是不是有x0x1x_0-x_1这个值能够在n上构成循环群,然后这个x0x1x_0-x_1能够形成一个循环群之类的。虽然我们知道这个(x0x1)e(x_0-x_1)^e的阶为d,不过我当时猜测是不是有比较小的阶也能满足这个条件,最后显然是失败了,爆破了很久都爆不出来。虽然后来出题放放出了hint提到我们都忽略了一个点

  • RSA算法的密钥对没有重新生成

但是我们不知道怎么来考虑这点,毕竟v的值除了受到密钥对的影响,还有随机数x。

寻找不变量

最后官方WP放出来之后,我们才理解这个题目怎么做。记得很久以前看过解决小学奥数题有一个根本思路:要在变量里面寻找不变量。在这个题目里面的不变量其实就是RSA的密钥对(n, e, d), 其中n,e我们又是已知的,d我们无法得知。不过认真看代码的话会发现有一个地方用到了这个值:

1
2
print((msg0 + pow(v - x0, key.d, key.n)) % key.n)
print((msg1 + pow(v - x1, key.d, key.n)) % key.n)

这里的F(v,xi)=(vxi)dmodnF(v,x_i)=(v-x_i)^d modn这个加密函数FF其实算是不变量,因为这里的d/n都是一个固定值。不过当时比赛的时候没想到这点怎么利用,后来看了答案得到了提示,那就是说

如果输入的值不变,那么就能获得一样的输出值

回想hintRSA算法的密钥对没有重新生成,这里其实暗示了一个点,那就是如果发起多次连接,RSA还是不变的。这里一个非常巧妙的地方就在于多次连接。我们知道,每次重新连接的时候,所有的信息都会被重置,不过在这里面,蕴藏着一个不变量,也就是我们之前提到的F加密函数,如果我们能够控制F函数的输入不变,那么我们就能够获得同一个输出!具体要怎么做呢?我们假设整个题目在第一次产生的变量叫做(x0,x1,msg0,msg1)(x_0,x_1,msg_0,msg_1),第二次生成的变量叫做(x0,x1,msg0,msg1)(x_0',x_1',msg_0',msg_1'),我们第一次连接的时候,能够知道(x0,x1)(x_0,x_1),输入的v为x0x_0,然后就能得到msg0msg_0msg1+(x0x1)dmodnmsg_1+(x_0-x_1)^dmodn
这里可以看到,我们的v取值为x0x_0,于是F(x0,x1)=(x0x1)dmodnF(x_0,x_1)=(x_0-x_1)^dmodn
这之后,我们不关闭这个连接,重新建立新的连接,此时能够得到(x0x1)(x_0'-x_1'),与此同时,我们知道在这一次,加密函数的写法变成了F(v,xi)=(vxi)dmodnF(v, x_i')=(v-x_i')^dmodn。第一次连接中我们需要推断的是(x0x1)dmodn(x_0-x_1)^dmodn,于是这里的xi>x1x_i->x_1。于是可以有如下推断

F(v,x1)=(vx1)dmodn(x0x1)dmodn=(x0x1+x1x1)dmodn=F(x0x1+x1,x1)\because F(v,x_1')=(v-x_1')^dmodn \\ \therefore (x_0-x_1)^dmodn = (x_0-x_1+x_1'-x_1')^dmodn = F(x_0-x_1+x_1', x_1')

因此,当我们的v=x0x1+x1v=x_0-x_1+x_1'时,被F(v,x1)F(v, x_1')加密过的值将会与F(x0,x1)F(x_0, x_1)相同,从而保证能够在多次连接中获得相同的加密值。至此,我们就在多个连接中找到了不变量。

题解推导

控制F获得同一个输出的意义在哪儿呢?首先,每一次都需要取猜测msg,乍一看每一次的数据都是独立的,不变量对于我们需要获取msg1有什么帮助呢?首先第一步,我们先确定我们要得到什么

需要知道msg的值

然后我们能够做什么

控制输入v

一个直观想法是,让v等于x0x_0,但是此时官方会打印两个值:

msg0+0msg1+(x0x1)dmodNmsg0+0 \\ msg1+(x_0-x_1)^d mod N

由于我们不知道d,所以无法算出(x0x1)dmodN(x_0-x_1)^d modN。到这里,有几种不同的思路:

  • 尝试获取d
  • 直接爆破msg1

这里考虑到我们已知的条件:

  • 多次连接下,可以获得相同的F(v,xi)F(v, x_i)
  • msg的每一个字节取值均为ascii

显然这个条件对于获取d没有什么帮助,乍一看好像也和直接爆破msg1没关系。我们这里设y=F(xi,v)y = F(x_i,v),并设官方会打印数字P = (msg1 + pow(v - x1, key.d, key.n)) % key.n。如果我们进行了很多次的连接,同时将y控制不变,那么此时就有:

P=msg1+yP=msg1+yP=msg1+y...其中,msg1的每一个字节均为asciiP=msg1+y \\ P'=msg1'+y \\ P''=msg1''+y \\ ... \\ \text{其中,msg1的每一个字节均为ascii}

也就是说,每一次获得的msg都是一个区间值,因此就形成了一个不等式方程组。通过多次计算,y将会被逐渐限制在一个区间内,最后得到一个具体的值。这样我们就能够爆破得到y,从而猜测到msg的基本信息。带入不等式可以有

P=msg1+ymsg1=Py0x414141+384msg10x7a7a7a+3840x414141+384Py0x7a7a7a+384P0x7a7a7a+384yP0x414141+384P=msg1+y \Rightarrow msg1=P-y \\ \because \begin{matrix} \underbrace{ 0x414141+\cdots } \\ 384 \end{matrix} \leq msg1 \leq \begin{matrix} \underbrace{ 0x7a7a7a+\cdots } \\ 384 \end{matrix} \\ \therefore \begin{matrix} \underbrace{ 0x414141+\cdots } \\ 384 \end{matrix} \leq P-y \leq \begin{matrix} \underbrace{ 0x7a7a7a+\cdots } \\ 384 \end{matrix} \\ \therefore P - \begin{matrix} \underbrace{ 0x7a7a7a+\cdots } \\ 384 \end{matrix} \leq y \leq P - \begin{matrix} \underbrace{ 0x414141+\cdots } \\ 384 \end{matrix}

不过说起来,仔细考虑的话会发现,由于这里考虑的粒度比较粗,所以有一些情况会无法涵盖(例如msg的取值不可能为0x412041.....)所以此时得到的范围会相对来说比较宽泛,我们需要缩小取值范围

缩小取值范围

我们首先要注意msg1并不是真正意义上的从[0x41414141....~0x7a7a7a7a7a...],之前的例子理也出现过了,例如0x41204141.....这种数字是不会出现的,所以我们与其整体考虑,不如拆分成每一个数字来考虑。也就是我们将考虑每一个数字可能的取值范围

这里假设msg1=m0m1m2...m384,y=y0y1...y384P=P0P1...P384则此时有mi{ord(A)...ord(Z),ord(a)...ord(z)}那么此时y=Pmsg1yi{Piord(A)...Piord(Z),Piord(a)Piord(z)}\text{这里假设} msg1=m_0m_1m_2...m_{384}, y=y_0y_1...y_{384} P=P_0P_1...P_{384} \\ \text{则此时有} m_i \in \{ord('A')...ord('Z'), ord('a')...ord('z')\} \\ \text{那么此时}y = P-msg1 \Rightarrow y_i \in \{P_i-ord('A')...P_i-ord('Z'), P_i-ord('a')-P_i-ord('z')\}

那么我们取多次P,就能够将y缩小到一个比较小的范围里面。不过这里取值的时候,需要考虑到借位的问题:

假设第i-1个值:Pi1ord(A)则此时第i个数字必然发生借位,于是有Piord(z)1yiPiord(A)1假设第i-1个值:Pi1ord(z)此时下下标为i-1则此时第i个数字的取值中的左边可能发生借位也就是Piord(z)1yiPiord(A)\text{假设第i-1个值}: P_{i-1} \leq ord('A') \\ \text{则此时第i个数字必然发生借位,于是有} \\ P_i-ord('z')-1 \leq y_i \leq P_i-ord('A')-1 \\ \text{假设第i-1个值}: P_{i-1} \leq ord('z'),\text{此时下下标为i-1}\\ \text{则此时第i个数字的取值中的左边可能发生借位},\text{也就是}\\ P_i-ord('z')-1 \leq y_i \leq P_i-ord('A')

为了增大检测范围,此时可以用如下两条规则来确定:

  • Pi1ord(A)P_{i-1} \leq ord('A'),则必定左,右侧同时借位
  • Pi1ord(z)P_{i-1} \leq ord('z'),则必定左侧借位

其余情况一律当作不接位处理(毕竟是不停产生的随机数,所以可以稍微放松一点约束条件)

细节处理

除此之外,有一个小细节,是写poc才发现的。。。看代码这段:

1
2
msg0 = bytes_to_long(random_str(2048 // 8 - 1).encode())
msg1 = bytes_to_long(random_str(2048 // 8 - 1).encode())

实际上msg的长度没有真实随机数的长度长,所以实际上我们需要爆破的只有255个字节,第256个字节只需要考虑进位问题即可

最终爆破

当我们把范围限制在一个可以承受的范围的时候(可能的取值控制在500000左右之后),我们就可以尝试去爆破y的具体取值(毕竟原题目是一个需要交互的题目,减少交互可以减小网络等问题的影响)。具体怎么做呢?由于此时我们已知y=msg1+(x0x1)dmodNy=msg1+(x_0-x_1)^dmodN ,此时我们可以确认的已知量为 x0x1x_0-x_1 ,所以我们的算式需要围绕这个值来进行爆破。由于我们使用了RSA对这个数字进行加密,所以我们可以遍历区间中的所有取值,检查 (ymsg1)emodN==x0x1(y-msg1)^emodN == x_0-x_1。如果成立的话,则代表我们爆破的y是合理的。

这里贴上为了解题写的poc,并且改写成了本地的版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
import string
from Crypto.PublicKey import RSA
from Crypto.Util.number import bytes_to_long, long_to_bytes
from random import SystemRandom
from itertools import product
import gmpy2
import codecs
class NumSet(object):
def __init__(self, init_set, index):
self.final_set = init_set
self.possible = set()
self.index = index
def add_possible(self, num):
# if num < 0:
# num += 256
self.possible.add(num % 256)
def check_possible(self):
"""
becuase not all number is possible, here we
should and the final_set with possible set
And finally, we will clear the possible set
"""
self.final_set &= self.possible
self.possible = set()
condition = len(self.final_set) > 0
if not condition:
print("index is " + str(self.index))
# assert(condition)
return len(self.final_set) > 0
def and_possible(self, s):
if type(s) == type(set()):
self.possible &= s
else:
print("Please check input type")
def get_length(self):
return len(self.final_set)
def in_set(self, n):
return n in self.final_set
def __str__(self):
return "the set is {}".format(self.final_set)
class Server(object):
def __init__(self):
self.key = None
self.x0 = self.x1 = 0
self.msg0 = self.msg1 = 0
def init_server(self):
self.random = SystemRandom()
if os.path.isfile("key.pem"):
with open("key.pem", "rb") as f:
key = RSA.importKey(f.read())
else:
key = RSA.generate(2048)
with open("key.pem", "wb") as f:
f.write(key.exportKey("PEM"))
self.key = key
def get_key(self):
return self.key
def generate_x0_x1(self):
x0 = self.random.randrange(self.key.n)
x1 = self.random.randrange(self.key.n)
self.x0, self.x1 = x0, x1
return x0, x1
def random_str(self, n):
return "".join([self.random.choice(string.ascii_letters) for _ in range(n)])
def generate_value(self, v):
msg0 = bytes_to_long(self.random_str(2048 // 8 - 1).encode())
msg1 = bytes_to_long(self.random_str(2048 // 8 - 1).encode())
value0 = (msg0 + pow(v - self.x0, self.key.d, self.key.n)) % self.key.n
value1 = (msg1 + pow(v - self.x1, self.key.d, self.key.n)) % self.key.n
self.msg0 = msg0
self.msg1 = msg1
return value0, value1
def check_answer(self, msg0, msg1):
if msg0 == self.msg0 and msg1 == self.msg1:
print("You find the answer!")
else:
print("Sorry the answer is wrong~")
possible_num = []
for i in range(256):
possible_num.append(NumSet({t for t in range(256)}, i))
# possible_num = [NumSet({i for i in range(256)}, i) for i in range(256)]
def check_scale(num):
global possible_num
# @param num:bytes number that used to limit value
left_carry = 0
right_carry = 0
for i in range(255, -1, -1):
each_num = num[i]
if i == 0:
possible_num[i].add_possible(each_num - left_carry)
possible_num[i].add_possible(each_num)
possible_num[i].check_possible()
else:
for c in range(ord('A'), ord('z')+1):
tmp = c
result = each_num - tmp
if left_carry == 1:
result -= 1
possible_num[i].add_possible(result)
if right_carry == 0:
possible_num[i].add_possible(each_num - ord('A'))
if not possible_num[i].check_possible() :
print(each_num)
print(possible_num[i].final_set)
assert(False)
if each_num < ord('A'):
left_carry = right_carry = 1
elif each_num < ord('z'):
left_carry = 1
right_carry = 0
else:
left_carry = right_carry = 0
def generate_number():
tmp_ans = []
tmp_num = []
for each in possible_num:
tmp_num.append(each.final_set)
for each in product(*tmp_num):
tmp_number = bytes_to_long(bytes(each))
tmp_ans.append(tmp_number)
return tmp_ans
def guess_num(value_1, key, we_need_know, x0x1):
global test_left_scale
bytes_value = long_to_bytes(value_1, 256)
check_scale(bytes_value)
all_possibles_number = 1
all_possibles_array = []
for i, each in enumerate(possible_num):
all_possibles_array.append(each.get_length())
all_possibles_number *= each.get_length()
print("{}:{}".format(all_possibles_number, all_possibles_array))
if all_possibles_number < 500000:
print("Ok, we can get length")
index = 0
ans_array = generate_number()
for m in ans_array:
m %= key.n
# if m == we_need_know:
if pow(m, key.e, key.n) == x0x1:
return m
print("Failed")
assert(False)
return -1
if __name__ == "__main__":
# global test_scale
# random = SystemRandom()
first_server = Server()
first_server.init_server()
x0, x1 = first_server.generate_x0_x1()
v = x0
msg0, value1 = first_server.generate_value(v)
key = first_server.get_key()
we_need_know = pow(v - x1, key.d, key.n)
guess0 = v
msg1 = 0
while True:
# here work as server
new_server = Server()
new_server.init_server()
x_0, x_1 = new_server.generate_x0_x1()
v = x0-x1+x_0
value_0, value_1 = new_server.generate_value(v)
num = guess_num(value_0, key, we_need_know, (x0-x1)%key.n)
if num == we_need_know:
msg1 = (value1 - num) % key.n
print("successful!")
break
num = -1
if num != -1:
break
first_server.check_answer(msg0, msg1)

后记

整个题目前前后后花了大概有半个月的时间来解,除了平时上班之外,其实剩余的时间也是不少,不过最近因为各种原因,心思总是不能集中在一件事情上,导致实际上公式老早就推导完成,但是实际上poc却写了有两周这样无厘头的事情。

关于密码学

  • 边界条件真的很重要。尤其这种带有猜测性质的问题,poc很多时候要么会猜测错误答案导致最后某个字节可能的取值为空,要么会少猜测一些值导致迟迟不能进行爆破,其实很大原因是因为边界值搞得不对。
  • 跳出纯粹的数学思维。这可能更加是针对协议来说,协议里面有一条RSA需要在每次握手的时候进行变化,为什么要这么做呢?当初我以为是有什么没见过的数学原理在里面,现在想想协议规定这一条实际上应该就是为了预防这种单次链接未结束的时候,进行多次链接这种特殊情况。在考虑协议的时候,除了正向考虑问题(直接从头到尾寻找整个流程中可以被攻击的算法),有时候也可以逆转思路来考虑(不是顺序触发流程中的条件,而是检查协议本身是否存在可以泄露信息的地方)