Add Lock Analyzer (#10430)

* add lock analyzer

* fix locks

* progress

* fix failures

* fix error log

Co-authored-by: Raul Jordan <raul@prysmaticlabs.com>
This commit is contained in:
Nishant Das
2022-04-06 00:39:48 +08:00
committed by GitHub
parent 984575ed57
commit 2a7a09b112
6 changed files with 297 additions and 62 deletions

View File

@@ -25,6 +25,45 @@ var Analyzer = &analysis.Analyzer{
}
var errNestedRLock = errors.New("found recursive read lock call")
var errNestedLock = errors.New("found recursive lock call")
var errNestedMixedLock = errors.New("found recursive mixed lock call")

// mode identifies which flavor of sync.RWMutex acquisition the analyzer
// is tracking: exclusive (Lock/Unlock) or shared (RLock/RUnlock).
type mode int

const (
	// LockMode tracks exclusive Lock/Unlock pairs.
	LockMode mode = iota
	// RLockMode tracks shared RLock/RUnlock pairs.
	RLockMode
)

// LockName returns the method name that acquires a lock in this mode,
// or the empty string for an unknown mode.
func (m mode) LockName() string {
	if m == LockMode {
		return "Lock"
	}
	if m == RLockMode {
		return "RLock"
	}
	return ""
}

// UnLockName returns the method name that releases a lock in this mode,
// or the empty string for an unknown mode.
func (m mode) UnLockName() string {
	if m == LockMode {
		return "Unlock"
	}
	if m == RLockMode {
		return "RUnlock"
	}
	return ""
}

// ErrorFound returns the sentinel error reported when a recursive
// acquisition of this mode's lock is detected, or nil for an unknown mode.
func (m mode) ErrorFound() error {
	if m == LockMode {
		return errNestedLock
	}
	if m == RLockMode {
		return errNestedRLock
	}
	return nil
}
func run(pass *analysis.Pass) (interface{}, error) {
inspect, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
@@ -33,6 +72,7 @@ func run(pass *analysis.Pass) (interface{}, error) {
}
nodeFilter := []ast.Node{
(*ast.GoStmt)(nil),
(*ast.CallExpr)(nil),
(*ast.DeferStmt)(nil),
(*ast.FuncDecl)(nil),
@@ -42,20 +82,36 @@ func run(pass *analysis.Pass) (interface{}, error) {
(*ast.ReturnStmt)(nil),
}
keepTrackOf := &tracker{}
keepTrackOf := &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
inspect.Preorder(nodeFilter, func(node ast.Node) {
if keepTrackOf.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.funcLitEnd {
if keepTrackOf.rLockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.rLockTrack.funcLitEnd &&
keepTrackOf.lockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.lockTrack.funcLitEnd {
return
}
keepTrackOf.funcLitEnd = token.NoPos
if keepTrackOf.deferEnd.IsValid() && node.Pos() > keepTrackOf.deferEnd {
keepTrackOf.deferEnd = token.NoPos
} else if keepTrackOf.deferEnd.IsValid() {
keepTrackOf.rLockTrack.funcLitEnd = token.NoPos
keepTrackOf.lockTrack.funcLitEnd = token.NoPos
if keepTrackOf.rLockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.deferEnd {
keepTrackOf.rLockTrack.deferEnd = token.NoPos
} else if keepTrackOf.rLockTrack.deferEnd.IsValid() {
return
}
if keepTrackOf.retEnd.IsValid() && node.Pos() > keepTrackOf.retEnd {
keepTrackOf.retEnd = token.NoPos
keepTrackOf.incFRU()
if keepTrackOf.lockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.deferEnd {
keepTrackOf.lockTrack.deferEnd = token.NoPos
} else if keepTrackOf.lockTrack.deferEnd.IsValid() {
return
}
if keepTrackOf.rLockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.retEnd {
keepTrackOf.rLockTrack.retEnd = token.NoPos
keepTrackOf.rLockTrack.incFRU()
}
if keepTrackOf.lockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.retEnd {
keepTrackOf.lockTrack.retEnd = token.NoPos
keepTrackOf.lockTrack.incFRU()
}
keepTrackOf = stmtSelector(node, pass, keepTrackOf, inspect)
})
@@ -64,57 +120,46 @@ func run(pass *analysis.Pass) (interface{}, error) {
func stmtSelector(node ast.Node, pass *analysis.Pass, keepTrackOf *tracker, inspect *inspector.Inspector) *tracker {
switch stmt := node.(type) {
case *ast.GoStmt:
keepTrackOf.rLockTrack.goroutinePos = stmt.Call.End()
keepTrackOf.lockTrack.goroutinePos = stmt.Call.End()
case *ast.CallExpr:
if stmt.End() == keepTrackOf.rLockTrack.goroutinePos ||
stmt.End() == keepTrackOf.lockTrack.goroutinePos {
keepTrackOf.rLockTrack.goroutinePos = 0
keepTrackOf.lockTrack.goroutinePos = 0
break
}
call := getCallInfo(pass.TypesInfo, stmt)
if call == nil {
break
}
name := call.name
selMap := mapSelTypes(stmt, pass)
if selMap == nil {
break
}
if keepTrackOf.rLockSelector != nil {
if keepTrackOf.foundRLock > 0 {
if keepTrackOf.rLockSelector.isEqual(selMap, 0) {
pass.Reportf(
node.Pos(),
fmt.Sprintf(
"%v",
errNestedRLock,
),
)
} else {
if stack := hasNestedRLock(keepTrackOf.rLockSelector, selMap, call, inspect, pass, make(map[string]bool)); stack != "" {
pass.Reportf(
node.Pos(),
fmt.Sprintf(
"%v\n%v",
errNestedRLock,
stack,
),
)
}
}
}
if name == "RUnlock" && keepTrackOf.rLockSelector.isEqual(selMap, 1) {
keepTrackOf.deincFRU()
}
} else if name == "RLock" && keepTrackOf.foundRLock == 0 {
keepTrackOf.rLockSelector = selMap
keepTrackOf.incFRU()
}
checkForRecLocks(node, pass, inspect, RLockMode, call, keepTrackOf.rLockTrack, selMap)
checkForRecLocks(node, pass, inspect, LockMode, call, keepTrackOf.lockTrack, selMap)
case *ast.File:
keepTrackOf = &tracker{}
keepTrackOf = &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
case *ast.FuncDecl:
keepTrackOf = &tracker{}
keepTrackOf.funcEnd = stmt.End()
keepTrackOf = &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
keepTrackOf.rLockTrack.funcEnd = stmt.End()
case *ast.FuncLit:
if keepTrackOf.funcLitEnd == token.NoPos {
keepTrackOf.funcLitEnd = stmt.End()
if keepTrackOf.rLockTrack.funcLitEnd == token.NoPos {
keepTrackOf.rLockTrack.funcLitEnd = stmt.End()
}
if keepTrackOf.lockTrack.funcLitEnd == token.NoPos {
keepTrackOf.lockTrack.funcLitEnd = stmt.End()
}
case *ast.IfStmt:
stmts := stmt.Body.List
@@ -124,48 +169,113 @@ func stmtSelector(node ast.Node, pass *analysis.Pass, keepTrackOf *tracker, insp
keepTrackOf = stmtSelector(stmt.Else, pass, keepTrackOf, inspect)
case *ast.DeferStmt:
call := getCallInfo(pass.TypesInfo, stmt.Call)
if keepTrackOf.deferEnd == token.NoPos {
keepTrackOf.deferEnd = stmt.End()
if keepTrackOf.rLockTrack.deferEnd == token.NoPos {
keepTrackOf.rLockTrack.deferEnd = stmt.End()
}
if call != nil && call.name == "RUnlock" {
keepTrackOf.deferredRUnlock = true
if keepTrackOf.lockTrack.deferEnd == token.NoPos {
keepTrackOf.lockTrack.deferEnd = stmt.End()
}
if call != nil && call.name == RLockMode.UnLockName() {
keepTrackOf.rLockTrack.deferredRUnlock = true
}
if call != nil && call.name == LockMode.UnLockName() {
keepTrackOf.lockTrack.deferredRUnlock = true
}
case *ast.ReturnStmt:
for i := 0; i < len(stmt.Results); i++ {
keepTrackOf = stmtSelector(stmt.Results[i], pass, keepTrackOf, inspect)
}
if keepTrackOf.deferredRUnlock && keepTrackOf.retEnd == token.NoPos {
keepTrackOf.deincFRU()
keepTrackOf.retEnd = stmt.End()
if keepTrackOf.rLockTrack.deferredRUnlock && keepTrackOf.rLockTrack.retEnd == token.NoPos {
keepTrackOf.rLockTrack.deincFRU()
keepTrackOf.rLockTrack.retEnd = stmt.End()
}
if keepTrackOf.lockTrack.deferredRUnlock && keepTrackOf.lockTrack.retEnd == token.NoPos {
keepTrackOf.lockTrack.deincFRU()
keepTrackOf.lockTrack.retEnd = stmt.End()
}
}
return keepTrackOf
}
type tracker struct {
rLockTrack *lockTracker
lockTrack *lockTracker
}
type lockTracker struct {
funcEnd token.Pos
retEnd token.Pos
deferEnd token.Pos
funcLitEnd token.Pos
goroutinePos token.Pos
deferredRUnlock bool
foundRLock int
rLockSelector *selIdentList
}
func (t tracker) String() string {
func (t lockTracker) String() string {
return fmt.Sprintf("funcEnd:%v\nretEnd:%v\ndeferEnd:%v\ndeferredRU:%v\nfoundRLock:%v\n", t.funcEnd, t.retEnd, t.deferEnd, t.deferredRUnlock, t.foundRLock)
}
func (t *tracker) deincFRU() {
func (t *lockTracker) deincFRU() {
if t.foundRLock > 0 {
t.foundRLock -= 1
}
}
func (t *tracker) incFRU() {
func (t *lockTracker) incFRU() {
t.foundRLock += 1
}
// checkForRecLocks inspects the call expression at node for a recursive or
// mixed acquisition of the given lock mode and reports a diagnostic on pass
// when one is found. lt carries the lock state accumulated so far for this
// mode; selMap is the selector chain of the call being inspected.
func checkForRecLocks(node ast.Node, pass *analysis.Pass, inspect *inspector.Inspector, lockmode mode, call *callInfo,
	lt *lockTracker, selMap *selIdentList) {
	name := call.name
	if lt.rLockSelector == nil {
		// First lock call seen for this mode: remember which selector
		// chain acquired it and start counting.
		if name == lockmode.LockName() && lt.foundRLock == 0 {
			lt.rLockSelector = selMap
			lt.incFRU()
		}
		return
	}
	if lt.foundRLock > 0 {
		// Pass the error as a Reportf argument rather than pre-formatting
		// it into the format string, so a `%` in the message cannot be
		// misinterpreted as a verb (go vet printf check).
		if lt.rLockSelector.isRelated(selMap, 0) {
			// Same mutex, opposite flavor: e.g. Lock while holding RLock.
			pass.Reportf(node.Pos(), "%v", errNestedMixedLock)
		}
		if lt.rLockSelector.isEqual(selMap, 0) {
			// Direct re-acquisition of the already-held lock.
			pass.Reportf(node.Pos(), "%v", lockmode.ErrorFound())
		} else if stack := hasNestedlock(lt.rLockSelector, lt.goroutinePos, selMap, call, inspect, pass, make(map[string]bool),
			lockmode.UnLockName()); stack != "" {
			// Indirect re-acquisition via a called function; include the
			// call stack that leads to it.
			pass.Reportf(node.Pos(), "%v\n%v", lockmode.ErrorFound(), stack)
		}
	}
	if name == lockmode.UnLockName() && lt.rLockSelector.isEqual(selMap, 1) {
		lt.deincFRU()
	}
	if name == lockmode.LockName() && lt.foundRLock == 0 && lt.rLockSelector.isEqual(selMap, 0) {
		lt.incFRU()
	}
}
// Stores the AST and type information of a single item in a selector expression
// For example, "a.b.c()", a selIdentNode might store the information for "a"
type selIdentNode struct {
@@ -221,6 +331,45 @@ func (s *selIdentList) isEqual(s2 *selIdentList, offset int) bool {
return true
}
// isRelated checks if our selectors are of the same type and
// reference the same underlying object. If they do we check
// if the provided list is referencing a non-equal but related
// lock. Ex: Lock - RLock, RLock - Lock
// TODO: Use a generalizable method here instead of hardcoding
// the lock definitions here.
func (s *selIdentList) isRelated(s2 *selIdentList, offset int) bool {
	// Chains of different lengths cannot refer to the same lock.
	if s2 == nil || (s.length != s2.length) {
		return false
	}
	// Rewind both cursors before walking the chains in lockstep.
	s.reset()
	s2.reset()
	for i := true; i; {
		// Every element up to the method name must match exactly;
		// otherwise the two calls target different receivers.
		if !s.current.isEqual(s2.current) {
			return false
		}
		if s.currentIndex < s.length-offset-1 && s.next() != nil {
			s2.next()
		} else {
			i = false
		}
		// Only check if we are at the last index for
		// related method calls.
		if s.currentIndex == s.length-1 {
			// The last element is the method name: "Lock" paired with
			// "RLock" (in either order) on the same receiver chain is a
			// mixed-lock relationship.
			switch s.current.this.String() {
			case LockMode.LockName():
				if s2.current.this.String() == RLockMode.LockName() {
					return true
				}
			case RLockMode.LockName():
				if s2.current.this.String() == LockMode.LockName() {
					return true
				}
			}
		}
	}
	return false
}
// getSub returns the shared beginning selIdentList of s and s2,
// if s contains all elements (except the last) of s2,
// and returns nil otherwise.
@@ -356,11 +505,12 @@ func interfaceMethod(s *types.Signature) bool {
return recv != nil && types.IsInterface(recv.Type())
}
// hasNestedRLock returns a stack trace of the nested or recursive RLock within the declaration of a function/method call (given by call).
// If the call expression does not contain a nested or recursive RLock, hasNestedRLock returns an empty string.
// hasNestedRLock finds a nested or recursive RLock by recursively calling itself on any functions called by the function/method represented
// hasNestedlock returns a stack trace of the nested or recursive lock within the declaration of a function/method call (given by call).
// If the call expression does not contain a nested or recursive lock, hasNestedlock returns an empty string.
// hasNestedlock finds a nested or recursive lock by recursively calling itself on any functions called by the function/method represented
// by callInfo.
func hasNestedRLock(fullRLockSelector *selIdentList, compareMap *selIdentList, call *callInfo, inspect *inspector.Inspector, pass *analysis.Pass, hist map[string]bool) (retStack string) {
func hasNestedlock(fullRLockSelector *selIdentList, goPos token.Pos, compareMap *selIdentList, call *callInfo, inspect *inspector.Inspector,
pass *analysis.Pass, hist map[string]bool, lockName string) (retStack string) {
var rLockSelector *selIdentList
f := pass.Fset
tInfo := pass.TypesInfo
@@ -390,20 +540,26 @@ func hasNestedRLock(fullRLockSelector *selIdentList, compareMap *selIdentList, c
addition := fmt.Sprintf("\t%q at %v\n", call.name, f.Position(call.call.Pos()))
ast.Inspect(node, func(iNode ast.Node) bool {
switch stmt := iNode.(type) {
case *ast.GoStmt:
goPos = stmt.End()
case *ast.CallExpr:
if stmt.End() == goPos {
goPos = 0
return false
}
c := getCallInfo(tInfo, stmt)
if c == nil {
return false
}
name := c.name
selMap := mapSelTypes(stmt, pass)
if rLockSelector.isEqual(selMap, 0) { // if the method found is an RLock method
if rLockSelector.isEqual(selMap, 0) || rLockSelector.isRelated(selMap, 0) { // if the method found is an RLock method
retStack += addition + fmt.Sprintf("\t%q at %v\n", name, f.Position(iNode.Pos()))
} else if name != "RUnlock" { // name should not equal the previousName to prevent infinite recursive loop
} else if name != lockName { // name should not equal the previousName to prevent infinite recursive loop
nt := c.id
if !hist[nt] { // make sure we are not in an infinite recursive loop
hist[nt] = true
stack := hasNestedRLock(rLockSelector, selMap, c, inspect, pass, hist)
stack := hasNestedlock(rLockSelector, goPos, selMap, c, inspect, pass, hist, lockName)
delete(hist, nt)
if stack != "" {
retStack += addition + stack

View File

@@ -13,6 +13,31 @@ func (p *ProtectResource) NestedMethod2() {
p.RUnlock()
}
// NestedMethodMixedLock calls a helper method while the exclusive lock
// is still held; the analyzer flags the nested acquisition.
func (p *ProtectResource) NestedMethodMixedLock() {
	p.Lock()
	p.GetResource() // want `found recursive lock call`
	p.Unlock()
}

// MixedLock takes the write lock while already holding the read lock on
// the same mutex, triggering the mixed-lock diagnostic.
func (p *ProtectResource) MixedLock() {
	p.RLock()
	p.Lock() // want `found recursive mixed lock call`
	p.Unlock()
	p.RUnlock()
}

// NestedMethodGoroutine launches the helper via `go`; calls started in a
// goroutine are not reported, so no diagnostic is expected.
func (p *ProtectResource) NestedMethodGoroutine() {
	p.RLock()
	defer p.RUnlock()
	go p.GetResource()
}

// NestedResourceGoroutine calls a helper that itself starts a goroutine;
// no diagnostic is expected here either.
func (p *ProtectResource) NestedResourceGoroutine() {
	p.RLock()
	defer p.RUnlock()
	p.GetResourceNestedGoroutine()
}
func (p *NestedProtectResource) MultiLevelStruct() {
p.nestedPR.RLock()
p.nestedPR.GetResource() // want `found recursive read lock call`

View File

@@ -15,9 +15,30 @@ func (p *ProtectResource) FuncLitInStructLit() {
p.RUnlock()
}
// FuncLitInStructLitLocked stores a function literal that takes the write
// lock inside a struct literal and invokes it while the lock is held.
func (p *ProtectResource) FuncLitInStructLitLocked() {
	p.Lock()
	type funcLitContainer struct {
		funcLit func()
	}
	var fl *funcLitContainer = &funcLitContainer{
		funcLit: func() {
			p.Lock()
		},
	}
	fl.funcLit() // this is a nested Lock but won't be caught
	p.Unlock()
}

// FuncReturnsMutex acquires the read lock first through a getter and then
// directly on the underlying mutex field.
func (e *ExposedMutex) FuncReturnsMutex() {
	e.GetLock().RLock()
	e.lock.RLock() // this is an obvious nested lock, but won't be caught since the first RLock was called through a getter function
	e.lock.RUnlock()
	e.GetLock().RUnlock()
}

// FuncReturnsMutexLocked mirrors FuncReturnsMutex for the exclusive lock.
func (e *ExposedMutex) FuncReturnsMutexLocked() {
	e.GetLock().Lock()
	e.lock.Lock() // this is an obvious nested lock, but won't be caught since the first Lock was called through a getter function
	e.lock.Unlock()
	e.GetLock().Unlock()
}

View File

@@ -12,3 +12,15 @@ func (resource *NestedProtectResource) NonNestedRLockDifferentRLocks() {
resource.GetNestedPResource() // get nested resource uses RLock, but at a deeper level in the struct
resource.RUnlock()
}
// NestedLockWithDefer holds the write lock (released via defer) while
// calling a method that also takes the write lock.
func (resource *ProtectResource) NestedLockWithDefer() string {
	resource.Lock()
	defer resource.Unlock()
	return resource.GetResourceLocked() // want `found recursive lock call`
}

// NonNestedLockDifferentLocks locks the outer struct while the helper
// locks a different, nested mutex, so no diagnostic is expected.
func (resource *NestedProtectResource) NonNestedLockDifferentLocks() {
	resource.Lock()
	resource.GetNestedPResourceLocked() // get nested resource uses Lock, but at a deeper level in the struct
	resource.Unlock()
}

View File

@@ -15,10 +15,24 @@ func (r *ProtectResource) GetResource() string {
return r.resource
}
// GetResourceLocked returns the resource while holding the write lock;
// the unlock is deferred before the lock is taken.
func (r *ProtectResource) GetResourceLocked() string {
	defer r.Unlock()
	r.Lock()
	return r.resource
}

// GetResourceNested delegates to GetResource.
func (r *ProtectResource) GetResourceNested() string {
	return r.GetResource()
}

// GetResourceNestedGoroutine runs GetResource on a new goroutine.
func (r *ProtectResource) GetResourceNestedGoroutine() {
	go r.GetResource()
}

// GetResourceNestedLock delegates to the write-locking variant.
func (r *ProtectResource) GetResourceNestedLock() string {
	return r.GetResourceLocked()
}
type NestedProtectResource struct {
*sync.RWMutex
nestedPR ProtectResource
@@ -30,6 +44,12 @@ func (r *NestedProtectResource) GetNestedPResource() string {
return r.nestedPR.resource
}
// GetNestedPResourceLocked reads the nested resource under the nested
// struct's write lock; the unlock is deferred before the lock is taken.
func (r *NestedProtectResource) GetNestedPResourceLocked() string {
	defer r.nestedPR.Unlock()
	r.nestedPR.Lock()
	return r.nestedPR.resource
}
type NotProtected struct {
resource string
}