深入分析ThreadLocal

深入分析ThreadLocal

前言

在我们平常的编程中,ThreadLocal用的场景可能并不多,但如果要承载一些线程相关的数据,且不想在方法中来回传递参数,则可以用ThreadLocal来保存,比如登录用户信息,你可以保存在ThreadLocal中,后面被调用的方法如果需要用户用户信息,则可以直接从ThreadLocal中取,不需要方法传递这个信息。ThreadLocal就是提供线程内的局部变量,说白了,就是ThreadLocal会在各个线程内部创建一个变量的副本,各个线程访问与修改这个变量不会干扰其他线程。

示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class ThreadLocalTest {

private static final ThreadLocal<Integer> threadLocal = new ThreadLocal<Integer>(){
@Override
protected Integer initialValue() {
return Integer.valueOf(0);
}
};

public static void main(String[] args) {
ExecutorService service = Executors.newFixedThreadPool(5);
for (int j = 0; j < 5; j++) {
service.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "'s first value is " + threadLocal.get());
for (int i = 0; i < 10; i++) {
threadLocal.set(threadLocal.get() + i);
}
System.out.println(Thread.currentThread() + "'s last value is " + threadLocal.get());
}
});
}
service.shutdown();
}
}

结果
Thread[pool-1-thread-1,5,main]'s first value is 0
Thread[pool-1-thread-1,5,main]'s last value is 45
Thread[pool-1-thread-2,5,main]'s first value is 0
Thread[pool-1-thread-2,5,main]'s last value is 45
Thread[pool-1-thread-3,5,main]'s first value is 0
Thread[pool-1-thread-3,5,main]'s last value is 45
Thread[pool-1-thread-4,5,main]'s first value is 0
Thread[pool-1-thread-4,5,main]'s last value is 45
Thread[pool-1-thread-5,5,main]'s first value is 0
Thread[pool-1-thread-5,5,main]'s last value is 45

从结果可以看出各线程之间的值不会相互影响

源码解析

首先,我们来看一下结构图,我们可以看到几点:

ThreadLocal_1

  • 每个Thread持有一个ThreadLocalMap。
  • ThreadLocalMap里存储的key是ThreadLocal对象,是一个弱引用;value是线程变量的副本。
  • ThreadLoca的set get方法本质上都是操作Thread里持有的ThradLocalMap。

Thread

1
2
3
4
5
6
public class Thread implements Runnable {

/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
}

ThreadLocalMap结构

ThreadLocalMap内部实现就是一个Entry的数组,Entry包含key和value,根据key的hash值与Entry数组的长度算出该放在数组的哪个索引。如果hash冲突,则采用开放定址法线性探测来存放。实现和HashMap比较类似,只不过处理hash冲突的方式不一样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
static class ThreadLocalMap {

//Entry的Key是弱引用,所以Key的引用对象是会在垃圾回收的时候被回收的
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
//初始容量16
private static final int INITIAL_CAPACITY = 16;
//Entry数组
private Entry[] table;
//Entry数组中Entry的数量
private int size = 0;
//扩容阈值
private int threshold; // Default to 0
//返回索引i的下一个索引,如果大于len,则返回0
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
//返回索引i的上一个索引,如果小于1,则返回len-1
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
//ThreadLocalMap初始化,初始化Entry数组,并存放Entry元素
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
setThreshold(INITIAL_CAPACITY);
}
}

ThreadLocal实现

ThreadLocal提供了几个核心方法set(), get(), remove(), initialValue()方法,我们首先看set()方法

set
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
public class ThreadLocal<T> {

public void set(T value) {
//获取当前线程
Thread t = Thread.currentThread();
//根据当前线程获取ThreadLocalMap,前面我们已经说过了每个Thread都持有一个ThreadLocalMap
ThreadLocalMap map = getMap(t);
//如果ThreadLocalMap不为空,则直接调用ThreadLocalMap的set方法;如果ThreadLocalMap为空,则创建ThreadLocalMap,并绑定给当前的线程
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
//参看ThreadLocalMap的构造方法
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
}

接下来我们看ThreadLocalMap的set方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
static class ThreadLocalMap {

private void set(ThreadLocal<?> key, Object value) {
//Entry数组
Entry[] tab = table;
//数组长度
int len = tab.length;
//计算key(也即ThreadLocal)存放的索引i
int i = key.threadLocalHashCode & (len-1);
//从i循环Entry数组,直到Entry数组元素为空;这是采用线性探测法来解决冲突,也即如果Hash冲突了,当前索引i的地方已经存放了Entry,则找到下一个为空的地方来进行存放
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
//获取Entry的key
ThreadLocal<?> k = e.get();
//如果key和传进来的key相同,则说明该ThreadLocal已经存在于Map中了,则将值进行替换,返回即可
if (k == key) {
e.value = value;
return;
}
//如果key为null,说明Key已经被回收了,则需要替换这个过期的Entry,后面会介绍
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
//创建一个新的Entry放在索引为i的空数组元素中,Entry数组的size进行自增
tab[i] = new Entry(key, value);
int sz = ++size;
//清除一些Entry的key为空的Entry元素,并进行扩容处理
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
//从staleSlot往前循环找,直到Entry数组元素为空。如果找到的Entry的key为null,设置slotToExpunge为找到的索引@1
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

//从staleSlot往后循环找,直到Entry数组元素为空。如果找到的Entry的key和传进来的key一样,则将找到的Entry的value赋值为新值,并将找到的Entry与索引位置为staleSlot的Entry进行交换。也即原来hash冲突的Entry被放到了别的位置,现在放回正确的位置。@2
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;

//如果@1步中往前没有找到key为空的Entry,则将slotToExpunge赋值为i。
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//清除slotToExpunge位置的Entry,并继续进行扫描清除后续Key为Null的Entry
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
//如果找到的Entry的key为null,且@1步中往前没有找到key为空的Entry,则将slotToExpunge赋值为i
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

//如果key没有找到,则将旧的value置为null,让value可以被GC,在staleSlot索引的地方创建一个新的Entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

//清除slotToExpunge位置的Entry,并继续进行扫描清除后续Key为Null的Entry
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

//清除指定位置的Entry,并进行重Hash
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
//将指定位置staleSlot的Entry的value置为null,以便进行垃圾回收,且将指定位置staleSlot数组元素置为null,Entry数组的size自减
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

Entry e;
int i;
//从staleSlot往后循环找,直到Entry数组元素为空
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
//取staleSlot + 1位置的Entry,且Entry不为null
ThreadLocal<?> k = e.get();
//如果Entry的key为null, 则将value置为null,且将位置staleSlot + 1的数组元素置为null
if (k == null) {
e.value = null;
tab[i] = null;
size--;
//Entry的key不为null,获取key的hash值
//如果根据hash值算出来的索引与staleSlot + i不相等,则将staleSlot + i位置的数组元素清空,将Entry重哈希到Entry数组中
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

//从i + 1开始扫描,扫描log2(n)次,n一般为table的长度,如果Entry的Key为null,则进行清除并返回true
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
//如果Entry不null且Entry的key为null,则进行清除
if (e != null && e.get() == null) {
n = len;
removed = true;
//调用expungeStaleEntry方法清除位置为i的Entry
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

//Map扩容
private void rehash() {
//扫描Entry数组,清除Key为空的Entry
expungeStaleEntries();
//如果size大于threshold - threshold / 4, 则进行扩容操作
if (size >= threshold - threshold / 4)
resize();
}

//扫描Entry数组,清除Key为空的Entry
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

private void resize() {
//保留老的Entry数组
Entry[] oldTab = table;
//老Entry数组的长度
int oldLen = oldTab.length;
//新Entry数组的长度,是老的长度的2倍
int newLen = oldLen * 2;
//创建Entry数组
Entry[] newTab = new Entry[newLen];
int count = 0;

//遍历每一个老Entry数组中的Entry元素,
//如果key为null的,则不映射到新Entry数组中
//如果key不为null,则重新进行Hash运算,映射到新的Entry数组中
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}
}

上面便是整个set的过程,还是比较复杂的,我觉得要关注2点,1、如果Entry的key在Hash冲突之后,在Entry数组中顺序往下找到一个为Null的地方存放,或者顺序找到key值一样的进行更新value的操作;2、在清除Key为Null的Entry后,要对后续的Entry进行重Hash,且要继续往后扫描log2(n)次清除Key为Null的Entry。接下来我们看get()方法

get
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public class ThreadLocal<T> {

public T get() {
//获取当前线程
Thread t = Thread.currentThread();
//根据当前线程获取ThreadLocalMap
ThreadLocalMap map = getMap(t);
//如果map不为null
if (map != null) {
//调用ThreadLocalMap的getEntry方法获取Entry对象,获取到的Entry对象不为空,则直接返回Value
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//如果map为null,或者Entry对象为null,则返回初始值
return setInitialValue();
}

private T setInitialValue() {
//调用initialValue方法获得初始值,就像我们demo中重写ThreadLocal的initialValue方法
T value = initialValue();
//下面的流程和上面分析的set流程是一致的
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
}

ThreadLocal的get方法比较简单,接下来看ThreadLocalMap的getEntry方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
static class ThreadLocalMap {

private Entry getEntry(ThreadLocal<?> key) {
//计算key在Entry数组中的索引
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
//如果Entry不为null,且Entry的key与传进来的key一致,则返回该Entry对象;否则循环Entry数组,去寻找Entry对象
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

//循环Entry数组,一直循环到Entry为null,找到key相等的Entry进行返回,否则返回null
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
}

以上就是get()方法,相比set()方法要简单许多。其实还有remove()方法,其实也比较简单,这里就不在做分析了。

结语

撸源码真的是比较耗脑力,耗精力,耗时间,但是真的看完以后,你会发现优秀的人写的代码确实很棒,有很多可学习的地方,而且对于之前死记硬背的东西又了更深的了解,比如Hash冲突后的解决方案,以及ThreadLocal使用不当会导致内存溢出等等。所以就是痛并快乐着吧!