From b8bbeae7404d85fdd719f4830d0e66ff862e1d97 Mon Sep 17 00:00:00 2001 From: Raul Jordan Date: Wed, 1 Feb 2023 09:32:01 -0500 Subject: [PATCH] Introduce Thread Safe Map Data Structure (#11940) * fix thread safety issue * gazelle * push up benchmarks * rev * rem keys method * shallow copy * fxi --- container/thread-safe/BUILD.bazel | 15 +++++ container/thread-safe/map.go | 59 +++++++++++++++++++ container/thread-safe/map_test.go | 97 +++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 container/thread-safe/BUILD.bazel create mode 100644 container/thread-safe/map.go create mode 100644 container/thread-safe/map_test.go diff --git a/container/thread-safe/BUILD.bazel b/container/thread-safe/BUILD.bazel new file mode 100644 index 0000000000..7b2d756aa0 --- /dev/null +++ b/container/thread-safe/BUILD.bazel @@ -0,0 +1,15 @@ +load("@prysm//tools/go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["map.go"], + importpath = "github.com/prysmaticlabs/prysm/v3/container/thread-safe", + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["map_test.go"], + embed = [":go_default_library"], + deps = ["//testing/require:go_default_library"], +) diff --git a/container/thread-safe/map.go b/container/thread-safe/map.go new file mode 100644 index 0000000000..96932f0757 --- /dev/null +++ b/container/thread-safe/map.go @@ -0,0 +1,59 @@ +// Package threadsafe contains generic containers that are +// protected either by Mutexes or atomics underneath the hood. +package threadsafe + +import "sync" + +// Map implements a simple thread-safe map protected by a mutex. +type Map[K comparable, V any] struct { + items map[K]V + lock sync.RWMutex +} + +// NewThreadSafeMap returns a thread-safe map instance from a normal map. +func NewThreadSafeMap[K comparable, V any](m map[K]V) *Map[K, V] { + return &Map[K, V]{ + items: m, + } +} + +// Keys returns the keys of a thread-safe map. +func (m *Map[K, V]) Keys() []K { + m.lock.RLock() + defer m.lock.RUnlock() + r := make([]K, 0, len(m.items)) + for k := range m.items { + key := k + r = append(r, key) + } + return r +} + +// Len of the thread-safe map. +func (m *Map[K, V]) Len() int { + m.lock.RLock() + defer m.lock.RUnlock() + return len(m.items) +} + +// Get an item from a thread-safe map. +func (m *Map[K, V]) Get(k K) (V, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + v, ok := m.items[k] + return v, ok +} + +// Put an item into a thread-safe map. +func (m *Map[K, V]) Put(k K, v V) { + m.lock.Lock() + defer m.lock.Unlock() + m.items[k] = v +} + +// Delete an item from a thread-safe map. +func (m *Map[K, V]) Delete(k K) { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.items, k) +} diff --git a/container/thread-safe/map_test.go b/container/thread-safe/map_test.go new file mode 100644 index 0000000000..8bbac1148a --- /dev/null +++ b/container/thread-safe/map_test.go @@ -0,0 +1,97 @@ +package threadsafe + +import ( + "sort" + "sync" + "testing" + + "github.com/prysmaticlabs/prysm/v3/testing/require" +) + +type safeMap struct { + items map[int]string + lock sync.RWMutex +} + +func (s *safeMap) Get(k int) (string, bool) { + s.lock.RLock() + defer s.lock.RUnlock() + v, ok := s.items[k] + return v, ok +} + +func (s *safeMap) Put(i int, str string) { + s.lock.Lock() + defer s.lock.Unlock() + s.items[i] = str +} + +func (s *safeMap) Delete(i int) { + s.lock.Lock() + defer s.lock.Unlock() + delete(s.items, i) +} + +func BenchmarkMap_Concrete(b *testing.B) { + mm := &safeMap{ + items: make(map[int]string), + } + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + mm.Put(j, "foo") + mm.Get(j) + mm.Delete(j) + } + } +} + +func BenchmarkMap_Generic(b *testing.B) { + items := make(map[int]string) + mm := NewThreadSafeMap(items) + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + mm.Put(j, "foo") + mm.Get(j) + mm.Delete(j) + } + } +} + +func TestMap(t *testing.T) { + m := map[int]string{ + 1: "foo", + 200: "bar", + 10000: "baz", + } + + tMap := NewThreadSafeMap(m) + keys := tMap.Keys() + sort.IntSlice(keys).Sort() + + require.DeepEqual(t, []int{1, 200, 10000}, keys) + require.Equal(t, 3, tMap.Len()) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(w *sync.WaitGroup, scopedMap *Map[int, string]) { + defer w.Done() + v, ok := scopedMap.Get(1) + require.Equal(t, true, ok) + require.Equal(t, "foo", v) + + scopedMap.Put(3, "nyan") + + v, ok = scopedMap.Get(3) + require.Equal(t, true, ok) + require.Equal(t, "nyan", v) + + }(&wg, tMap) + } + wg.Wait() + + tMap.Delete(3) + + _, ok := tMap.Get(3) + require.Equal(t, false, ok) +}