Added analyzer for detecting recursive/nested mutex read locks. (#10066)

* Added analyzer for detecting recursive/nested mutex read locks.

* Added type assertion checks and removed unused 'iTypes' directory

* Clean up

* Bazel file fixes

* Cleaned up code and added comments.

Co-authored-by: Raul Jordan <raul@prysmaticlabs.com>
This commit is contained in:
Chase Jeter
2022-01-14 01:12:21 -08:00
committed by GitHub
parent d73794ce72
commit 0ea4b02b8b
11 changed files with 758 additions and 0 deletions

View File

@@ -118,6 +118,13 @@
"tools/analyzers/shadowpredecl/testdata/shadow.go": "Analyzer testdata has to break rules"
}
},
"recursivelock": {
"exclude_files": {
"external/.*": "Third party code",
"rules_go_work-.*": "Third party code",
"tools/analyzers/recursivelock/testdata/.*": "Analyzer testdata has to break rules"
}
},
"nop": {
"exclude_files": {
"external/.*": "Third party code",

View File

@@ -0,0 +1,32 @@
load("@io_bazel_rules_go//go:def.bzl", "go_tool_library")
load("@prysm//tools/go:def.bzl", "go_library")
# Library target for the recursivelock analyzer. Only analyzer.go is listed;
# analyzer_test.go is kept out of generated rules by the gazelle:exclude
# directive at the end of this file.
go_library(
name = "go_default_library",
srcs = ["analyzer.go"],
importpath = "github.com/prysmaticlabs/prysm/tools/analyzers/recursivelock",
visibility = ["//visibility:public"],
deps = [
"@org_golang_x_tools//go/analysis:go_default_library",
"@org_golang_x_tools//go/analysis/passes/inspect:go_default_library",
"@org_golang_x_tools//go/ast/astutil:go_default_library",
"@org_golang_x_tools//go/ast/inspector:go_default_library",
"@org_golang_x_tools//go/types/typeutil:go_default_library",
],
)
# Same source and import path as go_default_library, built with the
# go_tool_library rule against the go_tool_library variant of every dependency.
# NOTE(review): presumably this is the target consumed by the nogo setup - confirm.
go_tool_library(
name = "go_tool_library",
srcs = ["analyzer.go"],
importpath = "github.com/prysmaticlabs/prysm/tools/analyzers/recursivelock",
visibility = ["//visibility:public"],
deps = [
"@org_golang_x_tools//go/analysis:go_tool_library",
"@org_golang_x_tools//go/analysis/passes/inspect:go_tool_library",
"@org_golang_x_tools//go/ast/astutil:go_tool_library",
"@org_golang_x_tools//go/ast/inspector:go_tool_library",
"@org_golang_x_tools//go/types/typeutil:go_tool_library",
],
)
# gazelle:exclude analyzer_test.go

View File

@@ -0,0 +1,475 @@
// Package recursivelock implements an analyzer that detects nested or recursive mutex read lock statements.
package recursivelock
import (
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil"
)
// Analyzer is the "recursivelock" analysis. Its Run function (run) reports
// calls made while a read lock is tracked as held that either repeat the
// tracked RLock call or reach it inside the called function.
// It requires the inspect pass, whose result run uses to traverse the AST.
var Analyzer = &analysis.Analyzer{
Name: "recursivelock",
Doc: "Checks for recursive or nested RLock calls",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: run,
}
// errNestedRLock supplies the message text of the diagnostics reported by run.
var errNestedRLock = errors.New("found recursive read lock call")
// run is the analyzer entry point. It visits the file, function, call, defer
// and return nodes of the package in source order and reports every call made
// while a read lock is tracked as held that either repeats the tracked RLock
// selector or reaches it inside the called function (see hasNestedRLock).
//
// State lives in a single tracker that is reset at every file and at every
// function declaration. Nodes located inside a function literal or inside a
// defer statement are skipped. Once a deferred RUnlock has been seen, the lock
// count is lowered for the extent of a return statement and raised again at
// the first node processed after that statement.
//
// The result value is always nil; an error is returned only when the inspect
// pass result does not have the expected type.
func run(pass *analysis.Pass) (interface{}, error) {
	inspect, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	if !ok {
		return nil, errors.New("analyzer is not type *inspector.Inspector")
	}

	nodeFilter := []ast.Node{
		(*ast.CallExpr)(nil),
		(*ast.DeferStmt)(nil),
		(*ast.FuncDecl)(nil),
		(*ast.FuncLit)(nil),
		(*ast.File)(nil),
		(*ast.ReturnStmt)(nil),
	}

	var keepTrackOf tracker
	inspect.Preorder(nodeFilter, func(node ast.Node) {
		// Skip everything that lies inside the function literal being tracked.
		if keepTrackOf.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.funcLitEnd {
			return
		}
		keepTrackOf.funcLitEnd = token.NoPos

		// Skip everything that lies inside the defer statement being tracked.
		if keepTrackOf.deferEnd.IsValid() && node.Pos() > keepTrackOf.deferEnd {
			keepTrackOf.deferEnd = token.NoPos
		} else if keepTrackOf.deferEnd.IsValid() {
			return
		}

		// Past the return statement: undo the temporary decrement made for it.
		if keepTrackOf.retEnd.IsValid() && node.Pos() > keepTrackOf.retEnd {
			keepTrackOf.retEnd = token.NoPos
			keepTrackOf.incFRU()
		}

		switch stmt := node.(type) {
		case *ast.CallExpr:
			call := getCallInfo(pass.TypesInfo, stmt)
			if call == nil {
				break
			}
			name := call.name
			selMap := mapSelTypes(stmt, pass)
			if selMap == nil {
				break
			}
			// NOTE(review): the RLock branch below only runs while
			// rLockSelector is nil, so a lock taken again after an RUnlock in
			// the same function does not appear to be re-tracked - confirm
			// whether that is intended.
			if keepTrackOf.rLockSelector != nil {
				if keepTrackOf.foundRLock > 0 {
					// The message is passed as an argument rather than as the
					// format string, so a '%' in it (for example in a file
					// path of the stack trace) is printed verbatim.
					if keepTrackOf.rLockSelector.isEqual(selMap, 0) {
						pass.Reportf(node.Pos(), "%v", errNestedRLock)
					} else if stack := hasNestedRLock(keepTrackOf.rLockSelector, selMap, call, inspect, pass, make(map[string]bool)); stack != "" {
						pass.Reportf(node.Pos(), "%v\n%v", errNestedRLock, stack)
					}
				}
				// Offset 1 ignores the final element, so x.RUnlock matches x.RLock.
				if name == "RUnlock" && keepTrackOf.rLockSelector.isEqual(selMap, 1) {
					keepTrackOf.deincFRU()
				}
			} else if name == "RLock" && keepTrackOf.foundRLock == 0 {
				keepTrackOf.rLockSelector = selMap
				keepTrackOf.incFRU()
			}
		case *ast.File:
			keepTrackOf = tracker{}
		case *ast.FuncDecl:
			keepTrackOf = tracker{}
			keepTrackOf.funcEnd = stmt.End()
		case *ast.FuncLit:
			if keepTrackOf.funcLitEnd == token.NoPos {
				keepTrackOf.funcLitEnd = stmt.End()
			}
		case *ast.DeferStmt:
			call := getCallInfo(pass.TypesInfo, stmt.Call)
			if keepTrackOf.deferEnd == token.NoPos {
				keepTrackOf.deferEnd = stmt.End()
			}
			if call != nil && call.name == "RUnlock" {
				keepTrackOf.deferredRUnlock = true
			}
		case *ast.ReturnStmt:
			// With a deferred RUnlock, calls inside the return expression are
			// evaluated with the lock count lowered by one.
			if keepTrackOf.deferredRUnlock && keepTrackOf.retEnd == token.NoPos {
				keepTrackOf.deincFRU()
				keepTrackOf.retEnd = stmt.End()
			}
		}
	})
	return nil, nil
}
// tracker holds the state that run carries from one visited node to the next.
// run replaces it with a zero value at every file and function declaration.
type tracker struct {
funcEnd token.Pos // end position of the current function declaration
retEnd token.Pos // end of the return statement for which foundRLock is temporarily lowered
deferEnd token.Pos // end of the defer statement whose inner nodes are skipped
funcLitEnd token.Pos // end of the function literal whose inner nodes are skipped
deferredRUnlock bool // set once a deferred call named RUnlock has been seen
foundRLock int // number of read locks currently counted as held
rLockSelector *selIdentList // selector of the first RLock call recorded in the function
}
// String renders the positional fields, the deferred-unlock flag and the lock
// counter, one "name:value" pair per line.
func (t tracker) String() string {
	const layout = "funcEnd:%v\nretEnd:%v\ndeferEnd:%v\ndeferredRU:%v\nfoundRLock:%v\n"
	return fmt.Sprintf(layout, t.funcEnd, t.retEnd, t.deferEnd, t.deferredRUnlock, t.foundRLock)
}
// deincFRU lowers the held-read-lock counter by one, never going below zero.
func (t *tracker) deincFRU() {
	if t.foundRLock <= 0 {
		return
	}
	t.foundRLock--
}
// incFRU raises the held-read-lock counter by one.
func (t *tracker) incFRU() {
	t.foundRLock++
}
// selIdentNode stores the AST and type information of a single item in a selector expression.
// For example, in "a.b.c()", a selIdentNode might store the information for "a".
type selIdentNode struct {
next *selIdentNode // item that follows this one in the selector, nil for the last item
this *ast.Ident // identifier of this item
typObj types.Object // object the identifier resolves to
}
// selIdentList is a linked list of selIdentNodes that stores the information of an entire selector expression.
// For example, each item in "a.b.c()" is stored as a node in this list, with the start node being "a".
type selIdentList struct {
start *selIdentNode // first (leftmost) item of the selector
length int // number of items in the selector
current *selIdentNode // used for internal functions
currentIndex int // used for internal functions
}
// next moves the cursor one item forward and returns the item it now points
// to. At the end of the list it returns nil and leaves the cursor unchanged.
func (s *selIdentList) next() (n *selIdentNode) {
	following := s.current.next
	if following == nil {
		return nil
	}
	s.current = following
	s.currentIndex++
	return following
}
// reset puts the cursor back on the first item of the list.
func (s *selIdentList) reset() {
	s.current, s.currentIndex = s.start, 0
}
// isEqual reports whether s and s2 have the same length and the same leading
// items. offset is the number of trailing items left out of the comparison:
// a.b.c() and a.b.d() are equal with an offset of 1. The first item is always
// compared, whatever the offset. Both cursors are reset and then advanced by
// the comparison.
func (s *selIdentList) isEqual(s2 *selIdentList, offset int) bool {
	if s2 == nil {
		return false
	}
	if s.length != s2.length {
		return false
	}
	s.reset()
	s2.reset()
	lastCompared := s.length - offset - 1
	for {
		if !s.current.isEqual(s2.current) {
			return false
		}
		// Stop at the last index to compare, or when s has no further item.
		if s.currentIndex >= lastCompared || s.next() == nil {
			return true
		}
		s2.next()
	}
}
// getSub checks that s starts with every item of s2 except s2's last one and,
// if so, returns the tail of s that begins at the last shared item.
// For example, if s represents "a.b.c.d()" and s2 represents "a.b.e()", the
// shared items are "a.b" and getSub returns a list representing "b.c.d".
// The returned list reuses the nodes of s.
//
// getSub returns nil when s2 is nil, when s2 has fewer than two items (there
// is no receiver part to share), when s2 is longer than s, or when one of the
// compared items differs. Both cursors are reset and then advanced.
func (s *selIdentList) getSub(s2 *selIdentList) *selIdentList {
	// With fewer than two items the loop below has no index to stop at, so it
	// could only return nil or never terminate; reject that case up front.
	if s2 == nil || s2.length < 2 || s2.length > s.length {
		return nil
	}
	s.reset()
	s2.reset()
	lastShared := s2.length - 2 // index of the item before s2's final item
	for {
		if !s.current.isEqual(s2.current) {
			return nil
		}
		if s2.currentIndex == lastShared {
			break
		}
		s.next()
		s2.next()
	}
	return &selIdentList{
		start:        s.current,
		length:       s.length - s.currentIndex,
		current:      s.current,
		currentIndex: 0,
	}
}
// changeRoot replaces the first item of the list with a new node built from r
// and t. The new node links to the old root's successor; the old root node
// itself is left untouched. The cursor follows only if it was on the old root.
func (s *selIdentList) changeRoot(r *ast.Ident, t types.Object) {
	oldRoot := s.start
	newRoot := &selIdentNode{this: r, next: oldRoot.next, typObj: t}
	if s.current == oldRoot {
		s.current = newRoot
	}
	s.start = newRoot
}
// String lists every item of the selector with its index, one per line, and
// prefixes the item at the cursor index with "*".
func (s selIdentList) String() (str string) {
	str = fmt.Sprintf("length: %v\n[\n", s.length)
	idx := 0
	for node := s.start; node != nil; node = node.next {
		if idx == s.currentIndex {
			str += "*"
		}
		str += fmt.Sprintf("%v: %v\n", idx, node)
		idx++
	}
	return str + "]"
}
// isEqual reports whether both nodes carry the same identifier name and
// resolve to the same types.Object.
func (s *selIdentNode) isEqual(s2 *selIdentNode) bool {
	if s.this.Name != s2.this.Name {
		return false
	}
	return s.typObj == s2.typObj
}
// String renders the node's identifier and the object it resolves to.
func (s selIdentNode) String() string {
	const layout = "{ ident: '%v', type: '%v' }"
	return fmt.Sprintf(layout, s.this, s.typObj)
}
// mapSelTypes builds the selIdentList for the function expression of c.
// It returns nil when that expression is not made up solely of identifiers
// and selector expressions.
func mapSelTypes(c *ast.CallExpr, pass *analysis.Pass) *selIdentList {
	var list selIdentList
	if !list.recurMapSelTypes(c.Fun, nil, pass.TypesInfo) {
		return nil
	}
	return &list
}
// recurMapSelTypes fills the list from the selector expression e, working
// from the rightmost item towards the root and prepending each item to next.
// Parentheses are stripped at every level. length is incremented once per
// visited level. It returns false as soon as a level is neither an identifier
// nor a selector expression; start and current are set only on success.
func (s *selIdentList) recurMapSelTypes(e ast.Expr, next *selIdentNode, t *types.Info) bool {
	tail := next
	for {
		s.length++
		switch x := astutil.Unparen(e).(type) {
		case *ast.Ident:
			// Reached the root of the selector: the list is complete.
			root := &selIdentNode{next: tail, this: x, typObj: t.ObjectOf(x)}
			s.current, s.start = root, root
			return true
		case *ast.SelectorExpr:
			node := &selIdentNode{next: tail, this: x.Sel}
			if sel, ok := t.Selections[x]; ok {
				node.typObj = sel.Obj() // method or field
			} else {
				node.typObj = t.Uses[x.Sel] // qualified identifier?
			}
			tail, e = node, x.X
		default:
			return false
		}
	}
}
// callInfo bundles a call expression with the identity of the function or
// method it calls, as filled in by getCallInfo.
type callInfo struct {
call *ast.CallExpr
id string // String representation of the type object
name string // type ID [either the name (if the function is exported) or the package/name if otherwise] of the function/method
}
// getCallInfo describes the callee of call. It returns nil when the callee
// cannot be resolved, is a builtin, or is a method whose receiver is an
// interface type.
func getCallInfo(tInfo *types.Info, call *ast.CallExpr) (c *callInfo) {
	callee := typeutil.Callee(tInfo, call)
	if callee == nil {
		return nil
	}
	if _, builtin := callee.(*types.Builtin); builtin {
		return nil
	}
	if sig, isSig := callee.Type().(*types.Signature); isSig && interfaceMethod(sig) {
		return nil
	}
	return &callInfo{
		call: call,
		id:   callee.String(),
		name: callee.Id(),
	}
}
// interfaceMethod reports whether s has a receiver whose type is an interface.
func interfaceMethod(s *types.Signature) bool {
	if recv := s.Recv(); recv != nil {
		return types.IsInterface(recv.Type())
	}
	return false
}
// hasNestedRLock looks inside the function called by call for a call whose
// selector matches the read lock described by fullRLockSelector, following
// further calls recursively.
//
// The callee body is either a function literal bound to a local identifier
// (see identifyFuncLitBlock) or a function/method declaration found in the
// package (see findCallDeclarationNode). In the second case the lock selector
// is first reduced with getSub and, for a method with a named receiver,
// re-rooted on that receiver so it can be compared inside the method body.
//
// Parameters:
//   - fullRLockSelector: selector of the RLock call currently tracked as held.
//   - compareMap: selector of the call being examined; may be nil.
//   - call: callee information of the call being examined.
//   - inspect: package inspector used to locate the callee's declaration.
//   - pass: current analysis pass (type information and file set).
//   - hist: ids of the callees currently being expanded, which stops the
//     search from looping on recursive call chains.
//
// It returns a trace with one `"name" at position` line per step leading to
// each matching call, or an empty string when no match is found.
func hasNestedRLock(fullRLockSelector *selIdentList, compareMap *selIdentList, call *callInfo, inspect *inspector.Inspector, pass *analysis.Pass, hist map[string]bool) (retStack string) {
	var rLockSelector *selIdentList
	f := pass.Fset
	tInfo := pass.TypesInfo
	cH := callHelper{
		call: call.call,
		fset: pass.Fset,
	}

	var node ast.Node
	if block := cH.identifyFuncLitBlock(cH.call.Fun); block != nil {
		// Local function literal: it sees the same identifiers as the caller,
		// so the full selector can be compared as is.
		rLockSelector = fullRLockSelector
		node = block
	} else {
		// Not a local function literal: if the selectors don't match up there
		// is nothing to look for.
		rLockSelector = fullRLockSelector.getSub(compareMap)
		if rLockSelector == nil {
			return ""
		}
		decl := findCallDeclarationNode(call, inspect, pass.TypesInfo)
		if decl == nil {
			return ""
		}
		// A method may declare its receiver without a name (func (*T) M()),
		// in which case Names is empty and there is no identifier to re-root
		// the selector on; the selector is then compared unchanged.
		if decl.Recv != nil && len(decl.Recv.List) > 0 && len(decl.Recv.List[0].Names) > 0 {
			recv := decl.Recv.List[0].Names[0]
			rLockSelector.changeRoot(recv, pass.TypesInfo.ObjectOf(recv))
		}
		node = decl
	}

	addition := fmt.Sprintf("\t%q at %v\n", call.name, f.Position(call.call.Pos()))
	ast.Inspect(node, func(iNode ast.Node) bool {
		stmt, isCall := iNode.(*ast.CallExpr)
		if !isCall {
			return true
		}
		c := getCallInfo(tInfo, stmt)
		if c == nil {
			// NOTE(review): returning false also skips calls nested in the
			// arguments of an unresolved/builtin/interface call - confirm
			// this is intended.
			return false
		}
		name := c.name
		selMap := mapSelTypes(stmt, pass)
		if rLockSelector.isEqual(selMap, 0) { // the call found is the tracked RLock
			retStack += addition + fmt.Sprintf("\t%q at %v\n", name, f.Position(iNode.Pos()))
		} else if name != "RUnlock" {
			nt := c.id
			if !hist[nt] { // make sure we are not in an infinite recursive loop
				hist[nt] = true
				stack := hasNestedRLock(rLockSelector, selMap, c, inspect, pass, hist)
				delete(hist, nt)
				if stack != "" {
					retStack += addition + stack
				}
			}
		}
		return true
	})
	return retStack
}
// findCallDeclarationNode scans every function declaration known to inspect
// and returns the one whose object string equals c.id. When several
// declarations match, the last one visited wins; nil is returned when none
// matches.
func findCallDeclarationNode(c *callInfo, inspect *inspector.Inspector, tInfo *types.Info) *ast.FuncDecl {
	var match *ast.FuncDecl
	onlyFuncDecls := []ast.Node{(*ast.FuncDecl)(nil)}
	inspect.Preorder(onlyFuncDecls, func(n ast.Node) {
		decl, isDecl := n.(*ast.FuncDecl)
		if !isDecl {
			return
		}
		if tInfo.ObjectOf(decl.Name).String() == c.id {
			match = decl
		}
	})
	return match
}
// callHelper carries a call expression and the file set of the current pass;
// it is the receiver of identifyFuncLitBlock.
type callHelper struct {
call *ast.CallExpr
fset *token.FileSet
}
// identifyFuncLitBlock resolves expr to the body of a function literal.
// expr may be the literal itself or an identifier that was declared (var) or
// assigned (:= / =) from one, possibly through further identifiers. Only
// declarations and assignments with as many names as values are followed.
// It returns nil when no function literal can be identified.
func (c callHelper) identifyFuncLitBlock(expr ast.Expr) *ast.BlockStmt {
	if lit, isLit := expr.(*ast.FuncLit); isLit {
		return lit.Body
	}
	id, isIdent := expr.(*ast.Ident)
	if !isIdent || id.Obj == nil {
		return nil
	}
	switch decl := id.Obj.Decl.(type) {
	case *ast.ValueSpec:
		if len(decl.Names) != len(decl.Values) {
			return nil
		}
		if i := findIdentIndex(id, decl.Names); i != -1 {
			return c.identifyFuncLitBlock(decl.Values[i])
		}
	case *ast.AssignStmt:
		// only deals with simple func lit assignments
		if len(decl.Lhs) != len(decl.Rhs) {
			return nil
		}
		if i := findIdentIndexFromExpr(id, decl.Lhs); i != -1 {
			return c.identifyFuncLitBlock(decl.Rhs[i])
		}
	}
	return nil
}
// findIdentIndex returns the index of the first identifier in exprs that has
// the same name as id, or -1 when there is none.
func findIdentIndex(id *ast.Ident, exprs []*ast.Ident) int {
	for i := 0; i < len(exprs); i++ {
		if exprs[i].Name == id.Name {
			return i
		}
	}
	return -1
}
// findIdentIndexFromExpr returns the index of the first expression in exprs
// that is an identifier with the same name as id, or -1 when there is none.
// Expressions that are not identifiers are ignored.
func findIdentIndexFromExpr(id *ast.Ident, exprs []ast.Expr) int {
	for i := range exprs {
		ident, isIdent := exprs[i].(*ast.Ident)
		if !isIdent {
			continue
		}
		if ident.Name == id.Name {
			return i
		}
	}
	return -1
}

View File

@@ -0,0 +1,11 @@
package recursivelock
import (
"testing"
"golang.org/x/tools/go/analysis/analysistest"
)
// TestAnalyzer runs the analyzer over the package's testdata directory and
// checks its diagnostics against the "want" comments found there.
func TestAnalyzer(t *testing.T) {
	testdataDir := analysistest.TestData()
	analysistest.Run(t, testdataDir, Analyzer)
}

View File

@@ -0,0 +1,15 @@
load("@prysm//tools/go:def.bzl", "go_library")
# Library target that lists the analyzer's testdata sources.
go_library(
name = "go_default_library",
srcs = [
"badlockswithmethods.go",
"badlockswithstructs.go",
"complexlocks.go",
"globallocks.go",
"nonrlocks.go",
"types.go",
],
importpath = "github.com/prysmaticlabs/prysm/tools/analyzers/recursivelock/testdata",
visibility = ["//visibility:public"],
)

View File

@@ -0,0 +1,26 @@
// recursive read lock calls with methods
package testdata
// NestedMethod holds p's read lock while calling GetResource, which takes the
// read lock of the same receiver again.
func (p *ProtectResource) NestedMethod() {
p.RLock()
p.GetResource() // want `found recursive read lock call`
p.RUnlock()
}
// NestedMethod2 reaches the second RLock one call deeper: GetResourceNested
// calls GetResource.
func (p *ProtectResource) NestedMethod2() {
p.RLock()
p.GetResourceNested() // want `found recursive read lock call`
p.RUnlock()
}
// MultiLevelStruct locks the nestedPR field and calls GetResource on that
// same field.
func (p *NestedProtectResource) MultiLevelStruct() {
p.nestedPR.RLock()
p.nestedPR.GetResource() // want `found recursive read lock call`
p.nestedPR.RUnlock()
}
// MultiLevelStruct2 locks the nestedPR field and calls a method on p that
// read-locks nestedPR itself.
func (p *NestedProtectResource) MultiLevelStruct2() {
p.nestedPR.RLock()
p.GetNestedPResource() // want `found recursive read lock call`
p.nestedPR.RUnlock()
}

View File

@@ -0,0 +1,51 @@
package testdata
// resource is the package-level ProtectResource locked by the functions below.
var resource *ProtectResource = &ProtectResource{resource: "protected"}
// nested holds a ProtectResource in its field p, one selector level deeper.
var nested *NestedResource = &NestedResource{p: ProtectResource{resource: "hello"}}
// NestedRLockWithStructs read-locks the package-level resource and then calls
// GetResource, which read-locks it again.
func NestedRLockWithStructs() {
resource.RLock()
resource.GetResource() // want `found recursive read lock call`
resource.RUnlock()
}
// NestedRLockWithParam is the same pattern with the resource passed in as a
// parameter.
func NestedRLockWithParam(r *ProtectResource) {
r.RLock()
r.GetResource() // want `found recursive read lock call`
r.RUnlock()
}
// NestedRLockWithMoreStructs is the same pattern through the field p of the
// package-level nested value.
func NestedRLockWithMoreStructs() {
nested.p.RLock()
nested.p.GetResource() // want `found recursive read lock call`
nested.p.RUnlock()
}
// NestedRLockWithFuncLit holds the read lock and calls function literals that
// take it again. Each literal is bound to an identifier in a different way.
func NestedRLockWithFuncLit() {
resource.RLock()
// literal bound through a var declaration
var varFuncLit func() = func() {
resource.RLock()
}
varFuncLit() // want `found recursive read lock call`
// literal bound through a short variable declaration
assignFuncLit := func() {
resource.RLock()
}
assignFuncLit() // want `found recursive read lock call`
// identifier assigned from another identifier that holds a literal
obfuscateFuncLit := varFuncLit
obfuscateFuncLit() // want `found recursive read lock call`
// two literals bound by one var declaration
var multiVarFuncLit1, multiVarFuncLit2 func() = func() {
resource.RLock()
}, func() {
resource.RLock()
}
multiVarFuncLit1() // want `found recursive read lock call`
multiVarFuncLit2() // want `found recursive read lock call`
// two literals bound by one short variable declaration
multiAssignFuncLit1, multiAssignFuncLit2 := func() {
resource.RLock()
}, func() {
resource.RLock()
}
multiAssignFuncLit1() // want `found recursive read lock call`
multiAssignFuncLit2() // want `found recursive read lock call`
resource.RUnlock()
}

View File

@@ -0,0 +1,23 @@
// These nested rlock patterns are too complex for the analyzer to catch right now
package testdata
// FuncLitInStructLit holds p's read lock and calls a function literal that is
// stored in a struct field and takes the lock again. No diagnostic is expected.
func (p *ProtectResource) FuncLitInStructLit() {
p.RLock()
type funcLitContainer struct {
funcLit func()
}
var fl *funcLitContainer = &funcLitContainer{
funcLit: func() {
p.RLock()
},
}
fl.funcLit() // this is a nested RLock but won't be caught
p.RUnlock()
}
// FuncReturnsMutex takes the read lock through the GetLock getter and then
// again directly on the field. No diagnostic is expected.
func (e *ExposedMutex) FuncReturnsMutex() {
e.GetLock().RLock()
e.lock.RLock() // this is an obvious nested lock, but won't be caught since the first RLock was called through a getter function
e.lock.RUnlock()
e.GetLock().RUnlock()
}

View File

@@ -0,0 +1,52 @@
// THE ANALYZER DOES NOT CATCH RLOCKS WHEN THE MUTEX VARIABLE IS OUTSIDE OF SCOPE LIKE THIS
package testdata
import (
"sync"
)
// mutex is a package-level lock, declared without an initializer, that every
// function in this file uses directly.
var mutex *sync.RWMutex
// RLockFuncs calls each of the lock patterns defined below.
func RLockFuncs() {
regularRLock()
nestedRLock1Level()
nestedRLock2Levels()
deferredRLock()
recursiveRLock(true)
}
// regularRLock takes and releases the read lock once.
func regularRLock() {
mutex.RLock()
mutex.RUnlock()
}
// nestedRLock1Level holds the read lock while calling regularRLock.
func nestedRLock1Level() {
mutex.RLock()
regularRLock() // this is a nested RLock, but the analyzer will not pick it up
mutex.RUnlock()
}
// nestedRLock2Levels holds the read lock while calling callRegularRLock,
// which in turn calls regularRLock.
func nestedRLock2Levels() {
mutex.RLock()
callRegularRLock()
mutex.RUnlock()
}
// callRegularRLock only forwards to regularRLock.
func callRegularRLock() {
regularRLock()
}
// deferredRLock holds the read lock, with the unlock deferred, while calling
// regularRLock.
func deferredRLock() {
mutex.RLock()
defer mutex.RUnlock()
regularRLock()
}
// recursiveRLock takes the read lock and, when run is true, calls itself once
// more with run negated while still holding it.
func recursiveRLock(run bool) {
mutex.RLock()
defer mutex.RUnlock()
if run {
recursiveRLock(!run)
}
}

View File

@@ -0,0 +1,14 @@
// These are all non recursive rlocks. Testing to make sure there are no false positives
package testdata
// NonNestedRLockWithDefer takes the read lock, defers the unlock and calls
// GetResource inside the return statement. No diagnostic is expected.
func (resource *ProtectResource) NonNestedRLockWithDefer() string {
resource.RLock()
defer resource.RUnlock()
return resource.GetResource() // this is not a nested rlock because runlock is deferred
}
// NonNestedRLockDifferentRLocks locks the receiver's own embedded mutex and
// calls a method that locks the mutex of the nestedPR field instead.
// No diagnostic is expected.
func (resource *NestedProtectResource) NonNestedRLockDifferentRLocks() {
resource.RLock()
resource.GetNestedPResource() // get nested resource uses RLock, but at a deeper level in the struct
resource.RUnlock()
}

View File

@@ -0,0 +1,52 @@
package testdata
import (
"sync"
)
// ProtectResource pairs a string with an embedded read/write mutex.
type ProtectResource struct {
*sync.RWMutex
resource string
}
// GetResource returns the resource under the read lock. The deferred unlock
// is written before the lock call.
func (r *ProtectResource) GetResource() string {
defer r.RUnlock()
r.RLock()
return r.resource
}
// GetResourceNested returns the resource by calling GetResource, so the read
// lock is taken one call deeper.
func (r *ProtectResource) GetResourceNested() string {
return r.GetResource()
}
// NestedProtectResource has its own embedded mutex and a ProtectResource
// field with a separate one.
type NestedProtectResource struct {
*sync.RWMutex
nestedPR ProtectResource
}
// GetNestedPResource returns the nested resource under the read lock of the
// nestedPR field, not the receiver's own mutex.
func (r *NestedProtectResource) GetNestedPResource() string {
defer r.nestedPR.RUnlock()
r.nestedPR.RLock()
return r.nestedPR.resource
}
// NotProtected holds a string without any mutex.
type NotProtected struct {
resource string
}
// GetResource returns the resource without locking.
func (r *NotProtected) GetResource() string {
return r.resource
}
// NestedResource embeds a NotProtected and holds a ProtectResource in field p.
type NestedResource struct {
*NotProtected
p ProtectResource
}
// ExposedMutex keeps its mutex in a named field instead of embedding it.
type ExposedMutex struct {
lock *sync.RWMutex
}
// GetLock returns the mutex held in the lock field.
func (e *ExposedMutex) GetLock() *sync.RWMutex {
return e.lock
}