shards: use x/sync/semaphore for throttling

A buffered channel is a simple and wrong solution for throttling: if
the lock() method

        n := cap(s.throttle)
        for n > 0 {
                s.throttle <- struct{}{}
                n--
        }

is called from multiple goroutines concurrently, none of them will
succeed inserting N items into the queue, deadlocking all of them.

Fixes #23.

Change-Id: I265655ba192755b096a062f9c164a8d872911964
diff --git a/shards/shards.go b/shards/shards.go
index 7de7618..289fe31 100644
--- a/shards/shards.go
+++ b/shards/shards.go
@@ -22,17 +22,36 @@
 	"sort"
 	"time"
 
+	"golang.org/x/sync/semaphore"
+
 	"github.com/google/zoekt"
 	"github.com/google/zoekt/query"
 )
 
+type shardedSearcher struct {
+	// Limit the number of parallel queries. Since searching is
+	// CPU bound, we can't do better than #CPU queries in
+	// parallel.  If we do so, we just create more memory
+	// pressure.
+	throttle *semaphore.Weighted
+	capacity int64
+
+	shards map[string]zoekt.Searcher
+}
+
+func newShardedSearcher(n int64) *shardedSearcher {
+	ss := &shardedSearcher{
+		shards:   make(map[string]zoekt.Searcher),
+		throttle: semaphore.NewWeighted(n),
+		capacity: n,
+	}
+	return ss
+}
+
 // NewDirectorySearcher returns a searcher instance that loads all
 // shards corresponding to a glob into memory.
 func NewDirectorySearcher(dir string) (zoekt.Searcher, error) {
-	ss := &shardedSearcher{
-		shards:   make(map[string]zoekt.Searcher),
-		throttle: make(chan struct{}, runtime.NumCPU()),
-	}
+	ss := newShardedSearcher(int64(runtime.NumCPU()))
 	_, err := NewDirectoryWatcher(dir, ss)
 	if err != nil {
 		return nil, err
@@ -47,23 +66,13 @@
 
 // Close closes references to open files. It may be called only once.
 func (ss *shardedSearcher) Close() {
-	ss.lock()
+	ss.lock(context.Background())
 	defer ss.unlock()
 	for _, s := range ss.shards {
 		s.Close()
 	}
 }
 
-type shardedSearcher struct {
-	// Limit the number of parallel queries. Since searching is
-	// CPU bound, we can't do better than #CPU queries in
-	// parallel.  If we do so, we just create more memory
-	// pressure.
-	throttle chan struct{}
-
-	shards map[string]zoekt.Searcher
-}
-
 func (ss *shardedSearcher) Search(ctx context.Context, pat query.Q, opts *zoekt.SearchOptions) (*zoekt.SearchResult, error) {
 	start := time.Now()
 	type res struct {
@@ -71,14 +80,16 @@
 		err error
 	}
 
-	aggregate := zoekt.SearchResult{
+	aggregate := &zoekt.SearchResult{
 		RepoURLs:      map[string]string{},
 		LineFragments: map[string]string{},
 	}
 
 	// This critical section is large, but we don't want to deal with
 	// searches on shards that have just been closed.
-	ss.rlock()
+	if err := ss.rlock(ctx); err != nil {
+		return aggregate, err
+	}
 	defer ss.runlock()
 	aggregate.Wait = time.Now().Sub(start)
 	start = time.Now()
@@ -147,7 +158,7 @@
 
 	zoekt.SortFilesByScore(aggregate.Files)
 	aggregate.Duration = time.Now().Sub(start)
-	return &aggregate, nil
+	return aggregate, nil
 }
 
 func (ss *shardedSearcher) List(ctx context.Context, r query.Q) (*zoekt.RepoList, error) {
@@ -156,7 +167,9 @@
 		err error
 	}
 
-	ss.rlock()
+	if err := ss.rlock(ctx); err != nil {
+		return nil, err
+	}
 	defer ss.runlock()
 
 	shards := ss.getShards()
@@ -210,8 +223,8 @@
 	}, nil
 }
 
-func (s *shardedSearcher) rlock() {
-	s.throttle <- struct{}{}
+func (s *shardedSearcher) rlock(ctx context.Context) error {
+	return s.throttle.Acquire(ctx, 1)
 }
 
 // getShards returns the currently loaded shards. The shards must be
@@ -225,23 +238,15 @@
 }
 
 func (s *shardedSearcher) runlock() {
-	<-s.throttle
+	s.throttle.Release(1)
 }
 
-func (s *shardedSearcher) lock() {
-	n := cap(s.throttle)
-	for n > 0 {
-		s.throttle <- struct{}{}
-		n--
-	}
+func (s *shardedSearcher) lock(ctx context.Context) error {
+	return s.throttle.Acquire(ctx, s.capacity)
 }
 
 func (s *shardedSearcher) unlock() {
-	n := cap(s.throttle)
-	for n > 0 {
-		<-s.throttle
-		n--
-	}
+	s.throttle.Release(s.capacity)
 }
 
 func (s *shardedSearcher) load(key string) {
@@ -259,7 +264,7 @@
 }
 
 func (s *shardedSearcher) replace(key string, shard zoekt.Searcher) {
-	s.lock()
+	s.lock(context.Background())
 	defer s.unlock()
 	old := s.shards[key]
 	if old != nil {
diff --git a/shards/shards_test.go b/shards/shards_test.go
index 23a3652..81b9522 100644
--- a/shards/shards_test.go
+++ b/shards/shards_test.go
@@ -60,11 +60,9 @@
 	out := &bytes.Buffer{}
 	log.SetOutput(out)
 	defer log.SetOutput(os.Stderr)
-	ss := &shardedSearcher{
-		shards: map[string]zoekt.Searcher{
-			"x": &crashSearcher{},
-		},
-		throttle: make(chan struct{}, 2),
+	ss := newShardedSearcher(2)
+	ss.shards = map[string]zoekt.Searcher{
+		"x": &crashSearcher{},
 	}
 
 	q := &query.Substring{Pattern: "hoi"}