shards: use x/sync/semaphore for throttling
A buffered channel is a simple and wrong solution for throttling: if
the lock() method
n := cap(s.throttle)
for n > 0 {
s.throttle <- struct{}{}
n--
}
is called from multiple goroutines concurrently, none of them will
succeed inserting N items into the queue, deadlocking all of them.
Fixes #23.
Change-Id: I265655ba192755b096a062f9c164a8d872911964
diff --git a/shards/shards.go b/shards/shards.go
index 7de7618..289fe31 100644
--- a/shards/shards.go
+++ b/shards/shards.go
@@ -22,17 +22,36 @@
"sort"
"time"
+ "golang.org/x/sync/semaphore"
+
"github.com/google/zoekt"
"github.com/google/zoekt/query"
)
+type shardedSearcher struct {
+ // Limit the number of parallel queries. Since searching is
+ // CPU bound, we can't do better than #CPU queries in
+ // parallel. If we do so, we just create more memory
+ // pressure.
+ throttle *semaphore.Weighted
+ capacity int64
+
+ shards map[string]zoekt.Searcher
+}
+
+func newShardedSearcher(n int64) *shardedSearcher {
+ ss := &shardedSearcher{
+ shards: make(map[string]zoekt.Searcher),
+ throttle: semaphore.NewWeighted(n),
+ capacity: n,
+ }
+ return ss
+}
+
// NewDirectorySearcher returns a searcher instance that loads all
// shards corresponding to a glob into memory.
func NewDirectorySearcher(dir string) (zoekt.Searcher, error) {
- ss := &shardedSearcher{
- shards: make(map[string]zoekt.Searcher),
- throttle: make(chan struct{}, runtime.NumCPU()),
- }
+ ss := newShardedSearcher(int64(runtime.NumCPU()))
_, err := NewDirectoryWatcher(dir, ss)
if err != nil {
return nil, err
@@ -47,23 +66,13 @@
// Close closes references to open files. It may be called only once.
func (ss *shardedSearcher) Close() {
- ss.lock()
+ ss.lock(context.Background())
defer ss.unlock()
for _, s := range ss.shards {
s.Close()
}
}
-type shardedSearcher struct {
- // Limit the number of parallel queries. Since searching is
- // CPU bound, we can't do better than #CPU queries in
- // parallel. If we do so, we just create more memory
- // pressure.
- throttle chan struct{}
-
- shards map[string]zoekt.Searcher
-}
-
func (ss *shardedSearcher) Search(ctx context.Context, pat query.Q, opts *zoekt.SearchOptions) (*zoekt.SearchResult, error) {
start := time.Now()
type res struct {
@@ -71,14 +80,16 @@
err error
}
- aggregate := zoekt.SearchResult{
+ aggregate := &zoekt.SearchResult{
RepoURLs: map[string]string{},
LineFragments: map[string]string{},
}
// This critical section is large, but we don't want to deal with
// searches on shards that have just been closed.
- ss.rlock()
+ if err := ss.rlock(ctx); err != nil {
+ return aggregate, err
+ }
defer ss.runlock()
aggregate.Wait = time.Now().Sub(start)
start = time.Now()
@@ -147,7 +158,7 @@
zoekt.SortFilesByScore(aggregate.Files)
aggregate.Duration = time.Now().Sub(start)
- return &aggregate, nil
+ return aggregate, nil
}
func (ss *shardedSearcher) List(ctx context.Context, r query.Q) (*zoekt.RepoList, error) {
@@ -156,7 +167,9 @@
err error
}
- ss.rlock()
+ if err := ss.rlock(ctx); err != nil {
+ return nil, err
+ }
defer ss.runlock()
shards := ss.getShards()
@@ -210,8 +223,8 @@
}, nil
}
-func (s *shardedSearcher) rlock() {
- s.throttle <- struct{}{}
+func (s *shardedSearcher) rlock(ctx context.Context) error {
+ return s.throttle.Acquire(ctx, 1)
}
// getShards returns the currently loaded shards. The shards must be
@@ -225,23 +238,15 @@
}
func (s *shardedSearcher) runlock() {
- <-s.throttle
+ s.throttle.Release(1)
}
-func (s *shardedSearcher) lock() {
- n := cap(s.throttle)
- for n > 0 {
- s.throttle <- struct{}{}
- n--
- }
+func (s *shardedSearcher) lock(ctx context.Context) error {
+ return s.throttle.Acquire(ctx, s.capacity)
}
func (s *shardedSearcher) unlock() {
- n := cap(s.throttle)
- for n > 0 {
- <-s.throttle
- n--
- }
+ s.throttle.Release(s.capacity)
}
func (s *shardedSearcher) load(key string) {
@@ -259,7 +264,7 @@
}
func (s *shardedSearcher) replace(key string, shard zoekt.Searcher) {
- s.lock()
+ s.lock(context.Background())
defer s.unlock()
old := s.shards[key]
if old != nil {
diff --git a/shards/shards_test.go b/shards/shards_test.go
index 23a3652..81b9522 100644
--- a/shards/shards_test.go
+++ b/shards/shards_test.go
@@ -60,11 +60,9 @@
out := &bytes.Buffer{}
log.SetOutput(out)
defer log.SetOutput(os.Stderr)
- ss := &shardedSearcher{
- shards: map[string]zoekt.Searcher{
- "x": &crashSearcher{},
- },
- throttle: make(chan struct{}, 2),
+ ss := newShardedSearcher(2)
+ ss.shards = map[string]zoekt.Searcher{
+ "x": &crashSearcher{},
}
q := &query.Substring{Pattern: "hoi"}