Skip to content

Commit 4817df9

Browse files
authored
Merge pull request #6 from corhere/rwmutex
Add RWMutexMap for read-write locks
2 parents dc2460c + 60c4cf9 commit 4817df9

File tree

2 files changed

+325
-0
lines changed

2 files changed

+325
-0
lines changed

rwmutexmap.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package locker
2+
3+
import (
4+
"sync"
5+
"sync/atomic"
6+
)
7+
8+
// RWMutexMap is a more convenient map[T]sync.RWMutex. It automatically makes
// and deletes mutexes as needed. Unlocked mutexes consume no memory.
//
// The zero value is a valid RWMutexMap.
type RWMutexMap[T comparable] struct {
	mu    sync.Mutex          // guards locks
	locks map[T]*rwlockCtr    // per-key lock state; created lazily, entries removed when idle
}
16+
17+
// rwlockCtr is used by RWMutexMap to represent a lock with a given key.
// The two counters let the map tell when a mutex is completely idle
// (no holders, no waiters) and can safely be removed and pooled.
type rwlockCtr struct {
	sync.RWMutex
	waiters atomic.Int32 // Number of callers waiting to acquire the lock
	readers atomic.Int32 // Number of readers currently holding the lock
}
23+
24+
// rwlockCtrPool recycles idle rwlockCtr values so that re-locking a key
// whose entry was previously deleted does not always allocate.
var rwlockCtrPool = sync.Pool{New: func() any { return new(rwlockCtr) }}
25+
26+
func (l *RWMutexMap[T]) get(key T) *rwlockCtr {
27+
if l.locks == nil {
28+
l.locks = make(map[T]*rwlockCtr)
29+
}
30+
31+
nameLock, exists := l.locks[key]
32+
if !exists {
33+
nameLock = rwlockCtrPool.Get().(*rwlockCtr)
34+
l.locks[key] = nameLock
35+
}
36+
return nameLock
37+
}
38+
39+
// Lock locks the RWMutex identified by key for writing.
40+
func (l *RWMutexMap[T]) Lock(key T) {
41+
l.mu.Lock()
42+
nameLock := l.get(key)
43+
44+
// Increment the nameLock waiters while inside the main mutex.
45+
// This makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently.
46+
nameLock.waiters.Add(1)
47+
l.mu.Unlock()
48+
49+
// Lock the nameLock outside the main mutex so we don't block other operations.
50+
// Once locked then we can decrement the number of waiters for this lock.
51+
nameLock.Lock()
52+
nameLock.waiters.Add(-1)
53+
}
54+
55+
// RLock locks the RWMutex identified by key for reading.
56+
func (l *RWMutexMap[T]) RLock(key T) {
57+
l.mu.Lock()
58+
nameLock := l.get(key)
59+
60+
nameLock.waiters.Add(1)
61+
l.mu.Unlock()
62+
63+
nameLock.RLock()
64+
// Increment the number of readers before decrementing the waiters
65+
// so concurrent calls to RUnlock will not see a glitch where both
66+
// waiters and readers are 0.
67+
nameLock.readers.Add(1)
68+
nameLock.waiters.Add(-1)
69+
}
70+
71+
// Unlock unlocks the RWMutex identified by key.
//
// It is a run-time error if the lock is not locked for writing on entry to Unlock.
func (l *RWMutexMap[T]) Unlock(key T) {
	l.mu.Lock()
	defer l.mu.Unlock()
	nameLock := l.get(key)
	// We don't have to do anything special to handle the error case:
	// l.get(key) will return an unlocked mutex.

	// No waiters and no readers means the counter becomes idle once we
	// unlock below, so remove it from the map and recycle it. The Put is
	// deferred so it runs after the Unlock call, while l.mu is still held
	// (defers run LIFO, so Put fires before the deferred l.mu.Unlock).
	if nameLock.waiters.Load() <= 0 && nameLock.readers.Load() <= 0 {
		delete(l.locks, key)
		defer rwlockCtrPool.Put(nameLock)
	}
	nameLock.Unlock()
}
87+
88+
// RUnlock unlocks the RWMutex identified by key for reading.
//
// It is a run-time error if the lock is not locked for reading on entry to RUnlock.
func (l *RWMutexMap[T]) RUnlock(key T) {
	l.mu.Lock()
	defer l.mu.Unlock()
	nameLock := l.get(key)
	// This caller is no longer counted as a reader.
	nameLock.readers.Add(-1)

	// If nobody is waiting and no other readers remain, the counter is
	// idle after the RUnlock below: drop it from the map and recycle it.
	// The deferred Put runs after the RUnlock, still under l.mu.
	if nameLock.waiters.Load() <= 0 && nameLock.readers.Load() <= 0 {
		delete(l.locks, key)
		defer rwlockCtrPool.Put(nameLock)
	}
	nameLock.RUnlock()
}
103+
104+
// Locker returns a [sync.Locker] interface that implements
105+
// the [sync.Locker.Lock] and [sync.Locker.Unlock] methods
106+
// by calling l.Lock(name) and l.Unlock(name).
107+
func (l *RWMutexMap[T]) Locker(key T) sync.Locker {
108+
return nameRWLocker[T]{l: l, key: key}
109+
}
110+
111+
// RLocker returns a [sync.Locker] interface that implements
112+
// the [sync.Locker.Lock] and [sync.Locker.Unlock] methods
113+
// by calling l.RLock(name) and l.RUnlock(name).
114+
func (l *RWMutexMap[T]) RLocker(key T) sync.Locker {
115+
return nameRLocker[T]{l: l, key: key}
116+
}
117+
118+
// nameRWLocker adapts a single key of an RWMutexMap to the [sync.Locker]
// interface, taking the write side of the lock.
type nameRWLocker[T comparable] struct {
	l   *RWMutexMap[T] // the map owning the per-key mutex
	key T              // the key whose mutex this locker controls
}

// nameRLocker is a nameRWLocker whose methods take the read side of the
// lock instead of the write side.
type nameRLocker[T comparable] nameRWLocker[T]

// Lock write-locks the underlying key.
func (n nameRWLocker[T]) Lock() {
	n.l.Lock(n.key)
}

// Unlock releases the write lock on the underlying key.
func (n nameRWLocker[T]) Unlock() {
	n.l.Unlock(n.key)
}

// Lock read-locks the underlying key.
func (n nameRLocker[T]) Lock() {
	n.l.RLock(n.key)
}

// Unlock releases the read lock on the underlying key.
func (n nameRLocker[T]) Unlock() {
	n.l.RUnlock(n.key)
}

rwmutexmap_test.go

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
package locker
2+
3+
import (
4+
"math/rand"
5+
"strconv"
6+
"sync"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestRWMutex_Lock(t *testing.T) {
12+
var l RWMutexMap[string]
13+
l.Lock("test")
14+
ctr := l.locks["test"]
15+
16+
if w := ctr.waiters.Load(); w != 0 {
17+
t.Fatalf("expected waiters to be 0, got %d", w)
18+
}
19+
20+
chDone := make(chan struct{})
21+
go func() {
22+
l.Lock("test")
23+
close(chDone)
24+
}()
25+
26+
chWaiting := make(chan struct{})
27+
go func() {
28+
for range time.Tick(1 * time.Millisecond) {
29+
if ctr.waiters.Load() == 1 {
30+
close(chWaiting)
31+
break
32+
}
33+
}
34+
}()
35+
36+
select {
37+
case <-chWaiting:
38+
case <-time.After(3 * time.Second):
39+
t.Fatal("timed out waiting for lock waiters to be incremented")
40+
}
41+
42+
select {
43+
case <-chDone:
44+
t.Fatal("lock should not have returned while it was still held")
45+
default:
46+
}
47+
48+
l.Unlock("test")
49+
50+
select {
51+
case <-chDone:
52+
case <-time.After(3 * time.Second):
53+
t.Fatalf("lock should have completed")
54+
}
55+
56+
if w := ctr.waiters.Load(); w != 0 {
57+
t.Fatalf("expected waiters to be 0, got %d", w)
58+
}
59+
}
60+
61+
func TestRWMutex_Unlock(t *testing.T) {
62+
var l RWMutexMap[string]
63+
64+
l.Lock("test")
65+
l.Unlock("test")
66+
67+
chDone := make(chan struct{})
68+
go func() {
69+
l.Lock("test")
70+
close(chDone)
71+
}()
72+
73+
select {
74+
case <-chDone:
75+
case <-time.After(3 * time.Second):
76+
t.Fatalf("lock should not be blocked")
77+
}
78+
}
79+
80+
// TestRWMutex_RLock checks reader/writer interaction on a single key:
// multiple read locks may be held at once, a pending write lock must wait
// for all readers to drop, and a held write lock keeps new readers out.
func TestRWMutex_RLock(t *testing.T) {
	var l RWMutexMap[string]
	rlocked := make(chan bool, 1)
	wlocked := make(chan bool, 1)
	n := 10

	go func() {
		for i := 0; i < n; i++ {
			// Take two read locks, signal, then try for the write lock
			// (which must block until both read locks are released).
			l.RLock("test")
			l.RLock("test")
			rlocked <- true
			l.Lock("test")
			wlocked <- true
		}
	}()

	for i := 0; i < n; i++ {
		<-rlocked
		// Releasing only one of the two read locks must not let the
		// writer through.
		l.RUnlock("test")
		select {
		case <-wlocked:
			t.Fatal("RLock() didn't block Lock()")
		default:
		}
		// Releasing the second read lock unblocks the writer.
		l.RUnlock("test")
		<-wlocked
		// While the write lock is held, the next iteration's RLock must
		// not have succeeded yet.
		select {
		case <-rlocked:
			t.Fatal("Lock() didn't block RLock()")
		default:
		}
		l.Unlock("test")
	}

	// Everything has been unlocked, so the map should have been emptied.
	if len(l.locks) != 0 {
		t.Fatalf("expected no locks to be present in the map, got %d", len(l.locks))
	}
}
118+
119+
func TestRWMutex_Concurrency(t *testing.T) {
120+
var l RWMutexMap[string]
121+
122+
var wg sync.WaitGroup
123+
for i := 0; i <= 10000; i++ {
124+
wg.Add(1)
125+
go func() {
126+
l.Lock("test")
127+
// if there is a concurrency issue, will very likely panic here
128+
l.Unlock("test")
129+
l.RLock("test")
130+
l.RUnlock("test")
131+
wg.Done()
132+
}()
133+
}
134+
135+
chDone := make(chan struct{})
136+
go func() {
137+
wg.Wait()
138+
close(chDone)
139+
}()
140+
141+
select {
142+
case <-chDone:
143+
case <-time.After(10 * time.Second):
144+
t.Fatal("timeout waiting for locks to complete")
145+
}
146+
147+
// Since everything has unlocked this should not exist anymore
148+
if ctr, exists := l.locks["test"]; exists {
149+
t.Fatalf("lock should not exist: %v", ctr)
150+
}
151+
}
152+
153+
func BenchmarkRWMutex(b *testing.B) {
154+
var l RWMutexMap[string]
155+
b.ReportAllocs()
156+
for i := 0; i < b.N; i++ {
157+
l.Lock("test")
158+
l.Lock(strconv.Itoa(i))
159+
l.Unlock(strconv.Itoa(i))
160+
l.Unlock("test")
161+
}
162+
}
163+
164+
// BenchmarkRWMutex_Parallel measures write-lock contention on a single
// key across many goroutines (128 per CPU).
func BenchmarkRWMutex_Parallel(b *testing.B) {
	var l RWMutexMap[string]
	b.SetParallelism(128)
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			l.Lock("test")
			l.Unlock("test")
		}
	})
}
174+
175+
func BenchmarkRWMutex_MoreKeys(b *testing.B) {
176+
var l RWMutexMap[string]
177+
var keys []string
178+
for i := 0; i < 64; i++ {
179+
keys = append(keys, strconv.Itoa(i))
180+
}
181+
b.SetParallelism(128)
182+
b.RunParallel(func(pb *testing.PB) {
183+
for pb.Next() {
184+
k := keys[rand.Intn(len(keys))]
185+
l.Lock(k)
186+
l.Unlock(k)
187+
}
188+
})
189+
}

0 commit comments

Comments
 (0)