mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 23:18:15 -05:00
Move Shared Packages into Math/ and IO/ (#9622)
* amend * building * build * userprompt * imports * build val * gaz * io file Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
24
io/file/BUILD.bazel
Normal file
24
io/file/BUILD.bazel
Normal file
@@ -0,0 +1,24 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["fileutil.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/io/file",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//shared/params:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["fileutil_test.go"],
|
||||
deps = [
|
||||
":go_default_library",
|
||||
"//shared/params:go_default_library",
|
||||
"//shared/testutil/assert:go_default_library",
|
||||
"//shared/testutil/require:go_default_library",
|
||||
],
|
||||
)
|
||||
323
io/file/fileutil.go
Normal file
323
io/file/fileutil.go
Normal file
@@ -0,0 +1,323 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/user"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ExpandPath given a string which may be a relative path.
|
||||
// 1. replace tilde with users home dir
|
||||
// 2. expands embedded environment variables
|
||||
// 3. cleans the path, e.g. /a/b/../c -> /a/c
|
||||
// Note, it has limitations, e.g. ~someuser/tmp will not be expanded
|
||||
func ExpandPath(p string) (string, error) {
|
||||
if strings.HasPrefix(p, "~/") || strings.HasPrefix(p, "~\\") {
|
||||
if home := HomeDir(); home != "" {
|
||||
p = home + p[1:]
|
||||
}
|
||||
}
|
||||
return filepath.Abs(path.Clean(os.ExpandEnv(p)))
|
||||
}
|
||||
|
||||
// HandleBackupDir takes an input directory path and either alters its permissions to be usable if it already exists, creates it if not
|
||||
func HandleBackupDir(dirPath string, permissionOverride bool) error {
|
||||
expanded, err := ExpandPath(dirPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exists, err := HasDir(expanded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
info, err := os.Stat(expanded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode().Perm() != params.BeaconIoConfig().ReadWriteExecutePermissions {
|
||||
if permissionOverride {
|
||||
if err := os.Chmod(expanded, params.BeaconIoConfig().ReadWriteExecutePermissions); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.New("dir already exists without proper 0700 permissions")
|
||||
}
|
||||
}
|
||||
}
|
||||
return os.MkdirAll(expanded, params.BeaconIoConfig().ReadWriteExecutePermissions)
|
||||
}
|
||||
|
||||
// MkdirAll takes in a path, expands it if necessary, and looks through the
|
||||
// permissions of every directory along the path, ensuring we are not attempting
|
||||
// to overwrite any existing permissions. Finally, creates the directory accordingly
|
||||
// with standardized, Prysm project permissions. This is the static-analysis enforced
|
||||
// method for creating a directory programmatically in Prysm.
|
||||
func MkdirAll(dirPath string) error {
|
||||
expanded, err := ExpandPath(dirPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exists, err := HasDir(expanded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
info, err := os.Stat(expanded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode().Perm() != params.BeaconIoConfig().ReadWriteExecutePermissions {
|
||||
return errors.New("dir already exists without proper 0700 permissions")
|
||||
}
|
||||
}
|
||||
return os.MkdirAll(expanded, params.BeaconIoConfig().ReadWriteExecutePermissions)
|
||||
}
|
||||
|
||||
// WriteFile is the static-analysis enforced method for writing binary data to a file
|
||||
// in Prysm, enforcing a single entrypoint with standardized permissions.
|
||||
func WriteFile(file string, data []byte) error {
|
||||
expanded, err := ExpandPath(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if FileExists(expanded) {
|
||||
info, err := os.Stat(expanded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.Mode() != params.BeaconIoConfig().ReadWritePermissions {
|
||||
return errors.New("file already exists without proper 0600 permissions")
|
||||
}
|
||||
}
|
||||
return ioutil.WriteFile(expanded, data, params.BeaconIoConfig().ReadWritePermissions)
|
||||
}
|
||||
|
||||
// HomeDir for a user.
|
||||
func HomeDir() string {
|
||||
if home := os.Getenv("HOME"); home != "" {
|
||||
return home
|
||||
}
|
||||
if usr, err := user.Current(); err == nil {
|
||||
return usr.HomeDir
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// HasDir checks if a directory indeed exists at the specified path.
|
||||
func HasDir(dirPath string) (bool, error) {
|
||||
fullPath, err := ExpandPath(dirPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
info, err := os.Stat(fullPath)
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
if info == nil {
|
||||
return false, err
|
||||
}
|
||||
return info.IsDir(), err
|
||||
}
|
||||
|
||||
// HasReadWritePermissions checks if file at a path has proper
|
||||
// 0600 permissions set.
|
||||
func HasReadWritePermissions(itemPath string) (bool, error) {
|
||||
info, err := os.Stat(itemPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return info.Mode() == params.BeaconIoConfig().ReadWritePermissions, nil
|
||||
}
|
||||
|
||||
// FileExists returns true if a file is not a directory and exists
|
||||
// at the specified path.
|
||||
func FileExists(filename string) bool {
|
||||
filePath, err := ExpandPath(filename)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
log.WithError(err).Info("Checking for file existence returned an error")
|
||||
}
|
||||
return false
|
||||
}
|
||||
return info != nil && !info.IsDir()
|
||||
}
|
||||
|
||||
// RecursiveFileFind returns true, and the path, if a file is not a directory and exists
|
||||
// at dir or any of its subdirectories. Finds the first instant based on the Walk order and returns.
|
||||
// Define non-fatal error to stop the recursive directory walk
|
||||
var stopWalk = errors.New("stop walking")
|
||||
|
||||
// RecursiveFileFind searches for file in a directory and its subdirectories.
|
||||
func RecursiveFileFind(filename, dir string) (bool, string, error) {
|
||||
var found bool
|
||||
var fpath string
|
||||
dir = filepath.Clean(dir)
|
||||
found = false
|
||||
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// checks if its a file and has the exact name as the filename
|
||||
// need to break the walk function by using a non-fatal error
|
||||
if !info.IsDir() && filename == info.Name() {
|
||||
found = true
|
||||
fpath = path
|
||||
return stopWalk
|
||||
}
|
||||
|
||||
// no errors or file found
|
||||
return nil
|
||||
})
|
||||
if err != nil && err != stopWalk {
|
||||
return false, "", err
|
||||
}
|
||||
return found, fpath, nil
|
||||
}
|
||||
|
||||
// ReadFileAsBytes expands a file name's absolute path and reads it as bytes from disk.
|
||||
func ReadFileAsBytes(filename string) ([]byte, error) {
|
||||
filePath, err := ExpandPath(filename)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not determine absolute path of password file")
|
||||
}
|
||||
return ioutil.ReadFile(filePath) // #nosec G304
|
||||
}
|
||||
|
||||
// CopyFile copy a file from source to destination path.
|
||||
func CopyFile(src, dst string) error {
|
||||
if !FileExists(src) {
|
||||
return errors.New("source file does not exist at provided path")
|
||||
}
|
||||
f, err := os.Open(src) // #nosec G304
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, params.BeaconIoConfig().ReadWritePermissions) // #nosec G304
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.Copy(dstFile, f)
|
||||
return err
|
||||
}
|
||||
|
||||
// CopyDir copies contents of one directory into another, recursively.
|
||||
func CopyDir(src, dst string) error {
|
||||
dstExists, err := HasDir(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dstExists {
|
||||
return errors.New("destination directory already exists")
|
||||
}
|
||||
fds, err := ioutil.ReadDir(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := MkdirAll(dst); err != nil {
|
||||
return errors.Wrapf(err, "error creating directory: %s", dst)
|
||||
}
|
||||
for _, fd := range fds {
|
||||
srcPath := path.Join(src, fd.Name())
|
||||
dstPath := path.Join(dst, fd.Name())
|
||||
if fd.IsDir() {
|
||||
if err = CopyDir(srcPath, dstPath); err != nil {
|
||||
return errors.Wrapf(err, "error copying directory %s -> %s", srcPath, dstPath)
|
||||
}
|
||||
} else {
|
||||
if err = CopyFile(srcPath, dstPath); err != nil {
|
||||
return errors.Wrapf(err, "error copying file %s -> %s", srcPath, dstPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DirsEqual checks whether two directories have the same content.
|
||||
func DirsEqual(src, dst string) bool {
|
||||
hash1, err := HashDir(src)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
hash2, err := HashDir(dst)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return hash1 == hash2
|
||||
}
|
||||
|
||||
// HashDir calculates and returns hash of directory: each file's hash is calculated and saved along
|
||||
// with the file name into the list, after which list is hashed to produce the final signature.
|
||||
// Implementation is based on https://github.com/golang/mod/blob/release-branch.go1.15/sumdb/dirhash/hash.go
|
||||
func HashDir(dir string) (string, error) {
|
||||
files, err := DirFiles(dir)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
files = append([]string(nil), files...)
|
||||
sort.Strings(files)
|
||||
for _, file := range files {
|
||||
fd, err := os.Open(filepath.Join(dir, file)) // #nosec G304
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
hf := sha256.New()
|
||||
_, err = io.Copy(hf, fd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := fd.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := fmt.Fprintf(h, "%x %s\n", hf.Sum(nil), file); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return "hashdir:" + base64.StdEncoding.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
// DirFiles returns list of files found within a given directory and its sub-directories.
|
||||
// Directory prefix will not be included as a part of returned file string i.e. for a file located
|
||||
// in "dir/foo/bar" only "foo/bar" part will be returned.
|
||||
func DirFiles(dir string) ([]string, error) {
|
||||
var files []string
|
||||
dir = filepath.Clean(dir)
|
||||
err := filepath.Walk(dir, func(file string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
relFile := file
|
||||
if dir != "." {
|
||||
relFile = file[len(dir)+1:]
|
||||
}
|
||||
files = append(files, filepath.ToSlash(relFile))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
434
io/file/fileutil_test.go
Normal file
434
io/file/fileutil_test.go
Normal file
@@ -0,0 +1,434 @@
|
||||
// Copyright 2015 The go-ethereum Authors
|
||||
// This file is part of go-ethereum.
|
||||
//
|
||||
// go-ethereum is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// go-ethereum is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
|
||||
package file_test
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/io/file"
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
func TestPathExpansion(t *testing.T) {
|
||||
user, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
tests := map[string]string{
|
||||
"/home/someuser/tmp": "/home/someuser/tmp",
|
||||
"~/tmp": user.HomeDir + "/tmp",
|
||||
"$DDDXXX/a/b": "/tmp/a/b",
|
||||
"/a/b/": "/a/b",
|
||||
}
|
||||
require.NoError(t, os.Setenv("DDDXXX", "/tmp"))
|
||||
for test, expected := range tests {
|
||||
expanded, err := file.ExpandPath(test)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, expanded)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMkdirAll_AlreadyExists_WrongPermissions(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
err = file.MkdirAll(dirName)
|
||||
assert.ErrorContains(t, "already exists without proper 0700 permissions", err)
|
||||
}
|
||||
|
||||
func TestMkdirAll_AlreadyExists_Override(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, params.BeaconIoConfig().ReadWriteExecutePermissions)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, file.MkdirAll(dirName))
|
||||
}
|
||||
|
||||
func TestHandleBackupDir_AlreadyExists_Override(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
info, err := os.Stat(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "drwxr-xr-x", info.Mode().String())
|
||||
assert.NoError(t, file.HandleBackupDir(dirName, true))
|
||||
info, err = os.Stat(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "drwx------", info.Mode().String())
|
||||
}
|
||||
|
||||
func TestHandleBackupDir_AlreadyExists_No_Override(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
info, err := os.Stat(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "drwxr-xr-x", info.Mode().String())
|
||||
err = file.HandleBackupDir(dirName, false)
|
||||
assert.ErrorContains(t, "dir already exists without proper 0700 permissions", err)
|
||||
info, err = os.Stat(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "drwxr-xr-x", info.Mode().String())
|
||||
}
|
||||
|
||||
func TestHandleBackupDir_NewDir(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
require.NoError(t, file.HandleBackupDir(dirName, true))
|
||||
info, err := os.Stat(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "drwx------", info.Mode().String())
|
||||
}
|
||||
|
||||
func TestMkdirAll_OK(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := file.MkdirAll(dirName)
|
||||
assert.NoError(t, err)
|
||||
exists, err := file.HasDir(dirName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestWriteFile_AlreadyExists_WrongPermissions(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
someFileName := filepath.Join(dirName, "somefile.txt")
|
||||
require.NoError(t, ioutil.WriteFile(someFileName, []byte("hi"), os.ModePerm))
|
||||
err = file.WriteFile(someFileName, []byte("hi"))
|
||||
assert.ErrorContains(t, "already exists without proper 0600 permissions", err)
|
||||
}
|
||||
|
||||
func TestWriteFile_AlreadyExists_OK(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
someFileName := filepath.Join(dirName, "somefile.txt")
|
||||
require.NoError(t, ioutil.WriteFile(someFileName, []byte("hi"), params.BeaconIoConfig().ReadWritePermissions))
|
||||
assert.NoError(t, file.WriteFile(someFileName, []byte("hi")))
|
||||
}
|
||||
|
||||
func TestWriteFile_OK(t *testing.T) {
|
||||
dirName := t.TempDir() + "somedir"
|
||||
err := os.MkdirAll(dirName, os.ModePerm)
|
||||
require.NoError(t, err)
|
||||
someFileName := filepath.Join(dirName, "somefile.txt")
|
||||
require.NoError(t, file.WriteFile(someFileName, []byte("hi")))
|
||||
exists := file.FileExists(someFileName)
|
||||
assert.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestCopyFile(t *testing.T) {
|
||||
fName := t.TempDir() + "testfile"
|
||||
err := ioutil.WriteFile(fName, []byte{1, 2, 3}, params.BeaconIoConfig().ReadWritePermissions)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = file.CopyFile(fName, fName+"copy")
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
assert.NoError(t, os.Remove(fName+"copy"))
|
||||
}()
|
||||
|
||||
assert.Equal(t, true, deepCompare(t, fName, fName+"copy"))
|
||||
}
|
||||
|
||||
func TestCopyDir(t *testing.T) {
|
||||
tmpDir1 := t.TempDir()
|
||||
tmpDir2 := filepath.Join(t.TempDir(), "copyfolder")
|
||||
type fileDesc struct {
|
||||
path string
|
||||
content []byte
|
||||
}
|
||||
fds := []fileDesc{
|
||||
{
|
||||
path: "testfile1",
|
||||
content: []byte{1, 2, 3},
|
||||
},
|
||||
{
|
||||
path: "subfolder1/testfile1",
|
||||
content: []byte{4, 5, 6},
|
||||
},
|
||||
{
|
||||
path: "subfolder1/testfile2",
|
||||
content: []byte{7, 8, 9},
|
||||
},
|
||||
{
|
||||
path: "subfolder2/testfile1",
|
||||
content: []byte{10, 11, 12},
|
||||
},
|
||||
{
|
||||
path: "testfile2",
|
||||
content: []byte{13, 14, 15},
|
||||
},
|
||||
}
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(tmpDir1, "subfolder1"), 0777))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(tmpDir1, "subfolder2"), 0777))
|
||||
for _, fd := range fds {
|
||||
require.NoError(t, file.WriteFile(filepath.Join(tmpDir1, fd.path), fd.content))
|
||||
assert.Equal(t, true, file.FileExists(filepath.Join(tmpDir1, fd.path)))
|
||||
assert.Equal(t, false, file.FileExists(filepath.Join(tmpDir2, fd.path)))
|
||||
}
|
||||
|
||||
// Make sure that files are copied into non-existent directory only. If directory exists function exits.
|
||||
assert.ErrorContains(t, "destination directory already exists", file.CopyDir(tmpDir1, t.TempDir()))
|
||||
require.NoError(t, file.CopyDir(tmpDir1, tmpDir2))
|
||||
|
||||
// Now, all files should have been copied.
|
||||
for _, fd := range fds {
|
||||
assert.Equal(t, true, file.FileExists(filepath.Join(tmpDir2, fd.path)))
|
||||
assert.Equal(t, true, deepCompare(t, filepath.Join(tmpDir1, fd.path), filepath.Join(tmpDir2, fd.path)))
|
||||
}
|
||||
assert.Equal(t, true, file.DirsEqual(tmpDir1, tmpDir2))
|
||||
}
|
||||
|
||||
func TestDirsEqual(t *testing.T) {
|
||||
t.Run("non-existent source directory", func(t *testing.T) {
|
||||
assert.Equal(t, false, file.DirsEqual(filepath.Join(t.TempDir(), "nonexistent"), t.TempDir()))
|
||||
})
|
||||
|
||||
t.Run("non-existent dest directory", func(t *testing.T) {
|
||||
assert.Equal(t, false, file.DirsEqual(t.TempDir(), filepath.Join(t.TempDir(), "nonexistent")))
|
||||
})
|
||||
|
||||
t.Run("non-empty directory", func(t *testing.T) {
|
||||
// Start with directories that do not have the same contents.
|
||||
tmpDir1, tmpFileNames := tmpDirWithContents(t)
|
||||
tmpDir2 := filepath.Join(t.TempDir(), "newfolder")
|
||||
assert.Equal(t, false, file.DirsEqual(tmpDir1, tmpDir2))
|
||||
|
||||
// Copy dir, and retest (hashes should match now).
|
||||
require.NoError(t, file.CopyDir(tmpDir1, tmpDir2))
|
||||
assert.Equal(t, true, file.DirsEqual(tmpDir1, tmpDir2))
|
||||
|
||||
// Tamper the data, make sure that hashes do not match anymore.
|
||||
require.NoError(t, os.Remove(filepath.Join(tmpDir1, tmpFileNames[2])))
|
||||
assert.Equal(t, false, file.DirsEqual(tmpDir1, tmpDir2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestHashDir(t *testing.T) {
|
||||
t.Run("non-existent directory", func(t *testing.T) {
|
||||
hash, err := file.HashDir(filepath.Join(t.TempDir(), "nonexistent"))
|
||||
assert.ErrorContains(t, "no such file or directory", err)
|
||||
assert.Equal(t, "", hash)
|
||||
})
|
||||
|
||||
t.Run("empty directory", func(t *testing.T) {
|
||||
hash, err := file.HashDir(t.TempDir())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hashdir:47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", hash)
|
||||
})
|
||||
|
||||
t.Run("non-empty directory", func(t *testing.T) {
|
||||
tmpDir, _ := tmpDirWithContents(t)
|
||||
hash, err := file.HashDir(tmpDir)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "hashdir:oSp9wRacwTIrnbgJWcwTvihHfv4B2zRbLYa0GZ7DDk0=", hash)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDirFiles(t *testing.T) {
|
||||
tmpDir, tmpDirFnames := tmpDirWithContents(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
outFiles []string
|
||||
}{
|
||||
{
|
||||
name: "dot path",
|
||||
path: filepath.Join(tmpDir, "/./"),
|
||||
outFiles: tmpDirFnames,
|
||||
},
|
||||
{
|
||||
name: "non-empty folder",
|
||||
path: tmpDir,
|
||||
outFiles: tmpDirFnames,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
outFiles, err := file.DirFiles(tt.path)
|
||||
require.NoError(t, err)
|
||||
|
||||
sort.Strings(outFiles)
|
||||
assert.DeepEqual(t, tt.outFiles, outFiles)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursiveFileFind(t *testing.T) {
|
||||
tmpDir, _ := tmpDirWithContentsForRecursiveFind(t)
|
||||
tests := []struct {
|
||||
name string
|
||||
root string
|
||||
path string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "file1",
|
||||
root: tmpDir,
|
||||
path: "subfolder1/subfolder11/file1",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "file2",
|
||||
root: tmpDir,
|
||||
path: "subfolder2/file2",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "file1",
|
||||
root: tmpDir + "/subfolder1",
|
||||
path: "subfolder11/file1",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "file3",
|
||||
root: tmpDir,
|
||||
path: "file3",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "file4",
|
||||
root: tmpDir,
|
||||
path: "",
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
found, _, err := file.RecursiveFileFind(tt.name, tt.root)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.DeepEqual(t, tt.found, found)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func deepCompare(t *testing.T, file1, file2 string) bool {
|
||||
sf, err := os.Open(file1)
|
||||
assert.NoError(t, err)
|
||||
df, err := os.Open(file2)
|
||||
assert.NoError(t, err)
|
||||
sscan := bufio.NewScanner(sf)
|
||||
dscan := bufio.NewScanner(df)
|
||||
|
||||
for sscan.Scan() && dscan.Scan() {
|
||||
if !bytes.Equal(sscan.Bytes(), dscan.Bytes()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tmpDirWithContents returns path to temporary directory having some folders/files in it.
|
||||
// Directory is automatically removed by internal testing cleanup methods.
|
||||
func tmpDirWithContents(t *testing.T) (string, []string) {
|
||||
dir := t.TempDir()
|
||||
fnames := []string{
|
||||
"file1",
|
||||
"file2",
|
||||
"subfolder1/file1",
|
||||
"subfolder1/file2",
|
||||
"subfolder1/subfolder11/file1",
|
||||
"subfolder1/subfolder11/file2",
|
||||
"subfolder1/subfolder12/file1",
|
||||
"subfolder1/subfolder12/file2",
|
||||
"subfolder2/file1",
|
||||
}
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, "subfolder1", "subfolder11"), 0777))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, "subfolder1", "subfolder12"), 0777))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, "subfolder2"), 0777))
|
||||
for _, fname := range fnames {
|
||||
require.NoError(t, ioutil.WriteFile(filepath.Join(dir, fname), []byte(fname), 0777))
|
||||
}
|
||||
sort.Strings(fnames)
|
||||
return dir, fnames
|
||||
}
|
||||
|
||||
// tmpDirWithContentsForRecursiveFind returns path to temporary directory having some folders/files in it.
|
||||
// Directory is automatically removed by internal testing cleanup methods.
|
||||
func tmpDirWithContentsForRecursiveFind(t *testing.T) (string, []string) {
|
||||
dir := t.TempDir()
|
||||
fnames := []string{
|
||||
"subfolder1/subfolder11/file1",
|
||||
"subfolder2/file2",
|
||||
"file3",
|
||||
}
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, "subfolder1", "subfolder11"), 0777))
|
||||
require.NoError(t, os.MkdirAll(filepath.Join(dir, "subfolder2"), 0777))
|
||||
for _, fname := range fnames {
|
||||
require.NoError(t, ioutil.WriteFile(filepath.Join(dir, fname), []byte(fname), 0777))
|
||||
}
|
||||
sort.Strings(fnames)
|
||||
return dir, fnames
|
||||
}
|
||||
|
||||
func TestHasReadWritePermissions(t *testing.T) {
|
||||
type args struct {
|
||||
itemPath string
|
||||
perms os.FileMode
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "0600 permissions returns true",
|
||||
args: args{
|
||||
itemPath: "somefile",
|
||||
perms: params.BeaconIoConfig().ReadWritePermissions,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other permissions returns false",
|
||||
args: args{
|
||||
itemPath: "somefile2",
|
||||
perms: params.BeaconIoConfig().ReadWriteExecutePermissions,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fullPath := filepath.Join(os.TempDir(), tt.args.itemPath)
|
||||
require.NoError(t, ioutil.WriteFile(fullPath, []byte("foo"), tt.args.perms))
|
||||
t.Cleanup(func() {
|
||||
if err := os.RemoveAll(fullPath); err != nil {
|
||||
t.Fatalf("Could not delete temp dir: %v", err)
|
||||
}
|
||||
})
|
||||
got, err := file.HasReadWritePermissions(fullPath)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("HasReadWritePermissions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("HasReadWritePermissions() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
29
io/logs/BUILD.bazel
Normal file
29
io/logs/BUILD.bazel
Normal file
@@ -0,0 +1,29 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"logutil.go",
|
||||
"stream.go",
|
||||
],
|
||||
importpath = "github.com/prysmaticlabs/prysm/io/logs",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//cache/lru:go_default_library",
|
||||
"//crypto/rand:go_default_library",
|
||||
"//shared/event:go_default_library",
|
||||
"//shared/params:go_default_library",
|
||||
"@com_github_hashicorp_golang_lru//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = [
|
||||
"logutil_test.go",
|
||||
"stream_test.go",
|
||||
],
|
||||
embed = [":go_default_library"],
|
||||
deps = ["//shared/testutil/require:go_default_library"],
|
||||
)
|
||||
55
io/logs/logutil.go
Normal file
55
io/logs/logutil.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Package logs creates a Multi writer instance that
|
||||
// write all logs that are written to stdout.
|
||||
package logs
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func addLogWriter(w io.Writer) {
|
||||
mw := io.MultiWriter(logrus.StandardLogger().Out, w)
|
||||
logrus.SetOutput(mw)
|
||||
}
|
||||
|
||||
// ConfigurePersistentLogging adds a log-to-file writer. File content is identical to stdout.
|
||||
func ConfigurePersistentLogging(logFileName string) error {
|
||||
logrus.WithField("logFileName", logFileName).Info("Logs will be made persistent")
|
||||
f, err := os.OpenFile(logFileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, params.BeaconIoConfig().ReadWritePermissions) // #nosec G304
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addLogWriter(f)
|
||||
|
||||
logrus.Info("File logging initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaskCredentialsLogging masks the url credentials before logging for security purpose
|
||||
// [scheme:][//[userinfo@]host][/]path[?query][#fragment] --> [scheme:][//[***]host][/***][#***]
|
||||
// if the format is not matched nothing is done, string is returned as is.
|
||||
func MaskCredentialsLogging(currUrl string) string {
|
||||
// error if the input is not a URL
|
||||
MaskedUrl := currUrl
|
||||
u, err := url.Parse(currUrl)
|
||||
if err != nil {
|
||||
return currUrl // Not a URL, nothing to do
|
||||
}
|
||||
// Mask the userinfo and the URI (path?query or opaque?query ) and fragment, leave the scheme and host(host/port) untouched
|
||||
if u.User != nil {
|
||||
MaskedUrl = strings.Replace(MaskedUrl, u.User.String(), "***", 1)
|
||||
}
|
||||
if len(u.RequestURI()) > 1 { // Ignore the '/'
|
||||
MaskedUrl = strings.Replace(MaskedUrl, u.RequestURI(), "/***", 1)
|
||||
}
|
||||
if len(u.Fragment) > 0 {
|
||||
MaskedUrl = strings.Replace(MaskedUrl, u.RawFragment, "***", 1)
|
||||
}
|
||||
return MaskedUrl
|
||||
}
|
||||
26
io/logs/logutil_test.go
Normal file
26
io/logs/logutil_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
var urltests = []struct {
|
||||
url string
|
||||
maskedUrl string
|
||||
}{
|
||||
{"https://a:b@xyz.net", "https://***@xyz.net"},
|
||||
{"https://eth-goerli.alchemyapi.io/v2/tOZG5mjl3.zl_nZdZTNIBUzsDq62R_dkOtY",
|
||||
"https://eth-goerli.alchemyapi.io/***"},
|
||||
{"https://google.com/search?q=golang", "https://google.com/***"},
|
||||
{"https://user@example.com/foo%2fbar", "https://***@example.com/***"},
|
||||
{"http://john@example.com/#x/y%2Fz", "http://***@example.com/#***"},
|
||||
{"https://me:pass@example.com/foo/bar?x=1&y=2", "https://***@example.com/***"},
|
||||
}
|
||||
|
||||
func TestMaskCredentialsLogging(t *testing.T) {
|
||||
for _, test := range urltests {
|
||||
require.Equal(t, MaskCredentialsLogging(test.url), test.maskedUrl)
|
||||
}
|
||||
}
|
||||
69
io/logs/stream.go
Normal file
69
io/logs/stream.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru"
|
||||
lruwrpr "github.com/prysmaticlabs/prysm/cache/lru"
|
||||
"github.com/prysmaticlabs/prysm/crypto/rand"
|
||||
"github.com/prysmaticlabs/prysm/shared/event"
|
||||
)
|
||||
|
||||
const (
|
||||
// The number of log entries to keep in memory.
|
||||
logCacheSize = 20
|
||||
)
|
||||
|
||||
var (
|
||||
// Compile time interface checks.
|
||||
_ = io.Writer(&StreamServer{})
|
||||
_ = Streamer(&StreamServer{})
|
||||
)
|
||||
|
||||
// Streamer defines a struct which can retrieve and stream process logs.
|
||||
type Streamer interface {
|
||||
GetLastFewLogs() [][]byte
|
||||
LogsFeed() *event.Feed
|
||||
}
|
||||
|
||||
// StreamServer defines a a websocket server which can receive events from
|
||||
// a feed and write them to open websocket connections.
|
||||
type StreamServer struct {
|
||||
feed *event.Feed
|
||||
cache *lru.Cache
|
||||
}
|
||||
|
||||
// NewStreamServer initializes a new stream server capable of
|
||||
// streaming log events.
|
||||
func NewStreamServer() *StreamServer {
|
||||
ss := &StreamServer{
|
||||
feed: new(event.Feed),
|
||||
cache: lruwrpr.New(logCacheSize),
|
||||
}
|
||||
addLogWriter(ss)
|
||||
return ss
|
||||
}
|
||||
|
||||
// GetLastFewLogs returns the last few entries of logs stored in an LRU cache.
|
||||
func (ss *StreamServer) GetLastFewLogs() [][]byte {
|
||||
messages := make([][]byte, 0)
|
||||
for _, k := range ss.cache.Keys() {
|
||||
d, ok := ss.cache.Get(k)
|
||||
if ok {
|
||||
messages = append(messages, d.([]byte))
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
// LogsFeed returns a feed callers can subscribe to to receive logs via a channel.
|
||||
func (ss *StreamServer) LogsFeed() *event.Feed {
|
||||
return ss.feed
|
||||
}
|
||||
|
||||
// Write a binary message and send over the event feed.
|
||||
func (ss *StreamServer) Write(p []byte) (n int, err error) {
|
||||
ss.feed.Send(p)
|
||||
ss.cache.Add(rand.NewGenerator().Uint64(), p)
|
||||
return len(p), nil
|
||||
}
|
||||
23
io/logs/stream_test.go
Normal file
23
io/logs/stream_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
func TestStreamServer_BackfillsMessages(t *testing.T) {
|
||||
ss := NewStreamServer()
|
||||
msgs := [][]byte{
|
||||
[]byte("foo"),
|
||||
[]byte("bar"),
|
||||
[]byte("buzz"),
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
_, err := ss.Write(msg)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
recentMessages := ss.GetLastFewLogs()
|
||||
require.DeepEqual(t, msgs, recentMessages)
|
||||
}
|
||||
30
io/prompt/BUILD.bazel
Normal file
30
io/prompt/BUILD.bazel
Normal file
@@ -0,0 +1,30 @@
|
||||
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = [
|
||||
"prompt.go",
|
||||
"validate.go",
|
||||
],
|
||||
importpath = "github.com/prysmaticlabs/prysm/io/prompt",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//io/file:go_default_library",
|
||||
"@com_github_logrusorgru_aurora//:go_default_library",
|
||||
"@com_github_nbutton23_zxcvbn_go//:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//:go_default_library",
|
||||
"@com_github_urfave_cli_v2//:go_default_library",
|
||||
"@org_golang_x_crypto//ssh/terminal:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["validate_test.go"],
|
||||
embed = [":go_default_library"],
|
||||
deps = [
|
||||
"//shared/testutil/assert:go_default_library",
|
||||
"//shared/testutil/require:go_default_library",
|
||||
],
|
||||
)
|
||||
172
io/prompt/prompt.go
Normal file
172
io/prompt/prompt.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/logrusorgru/aurora"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/io/file"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
var au = aurora.NewAurora(true)
|
||||
|
||||
// PasswordReaderFunc takes in a *file and returns a password using the terminal package
|
||||
func passwordReaderFunc(file *os.File) ([]byte, error) {
|
||||
pass, err := terminal.ReadPassword(int(file.Fd()))
|
||||
return pass, err
|
||||
}
|
||||
|
||||
// PasswordReader has passwordReaderFunc as the default but can be changed for testing purposes.
|
||||
var PasswordReader = passwordReaderFunc
|
||||
|
||||
// ValidatePrompt requests the user for text and expects the user to fulfill the provided validation function.
|
||||
func ValidatePrompt(r io.Reader, promptText string, validateFunc func(string) error) (string, error) {
|
||||
var responseValid bool
|
||||
var response string
|
||||
for !responseValid {
|
||||
fmt.Printf("%s:\n", au.Bold(promptText))
|
||||
scanner := bufio.NewScanner(r)
|
||||
if ok := scanner.Scan(); ok {
|
||||
item := scanner.Text()
|
||||
response = strings.TrimRight(item, "\r\n")
|
||||
if err := validateFunc(response); err != nil {
|
||||
fmt.Printf("Entry not valid: %s\n", au.BrightRed(err))
|
||||
} else {
|
||||
responseValid = true
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("could not scan text input")
|
||||
}
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// DefaultPrompt prompts the user for any text and performs no validation. If nothing is entered it returns the default.
|
||||
func DefaultPrompt(promptText, defaultValue string) (string, error) {
|
||||
var response string
|
||||
if defaultValue != "" {
|
||||
fmt.Printf("%s %s:\n", promptText, fmt.Sprintf("(%s: %s)", au.BrightGreen("default"), defaultValue))
|
||||
} else {
|
||||
fmt.Printf("%s:\n", promptText)
|
||||
}
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if ok := scanner.Scan(); ok {
|
||||
item := scanner.Text()
|
||||
response = strings.TrimRight(item, "\r\n")
|
||||
if response == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
return "", errors.New("could not scan text input")
|
||||
}
|
||||
|
||||
// DefaultAndValidatePrompt prompts the user for any text and expects it to fulfill a validation function. If nothing is entered
|
||||
// the default value is returned.
|
||||
func DefaultAndValidatePrompt(promptText, defaultValue string, validateFunc func(string) error) (string, error) {
|
||||
var responseValid bool
|
||||
var response string
|
||||
for !responseValid {
|
||||
fmt.Printf("%s %s:\n", promptText, fmt.Sprintf("(%s: %s)", au.BrightGreen("default"), defaultValue))
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
if ok := scanner.Scan(); ok {
|
||||
item := scanner.Text()
|
||||
response = strings.TrimRight(item, "\r\n")
|
||||
if response == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
if err := validateFunc(response); err != nil {
|
||||
fmt.Printf("Entry not valid: %s\n", au.BrightRed(err))
|
||||
} else {
|
||||
responseValid = true
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("could not scan text input")
|
||||
}
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// PasswordPrompt prompts the user for a password, that repeatedly requests the password until it qualifies the
|
||||
// passed in validation function.
|
||||
func PasswordPrompt(promptText string, validateFunc func(string) error) (string, error) {
|
||||
var responseValid bool
|
||||
var response string
|
||||
for !responseValid {
|
||||
fmt.Printf("%s: ", au.Bold(promptText))
|
||||
bytePassword, err := PasswordReader(os.Stdin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
response = strings.TrimRight(string(bytePassword), "\r\n")
|
||||
if err := validateFunc(response); err != nil {
|
||||
fmt.Printf("\nEntry not valid: %s\n", au.BrightRed(err))
|
||||
} else {
|
||||
fmt.Println("")
|
||||
responseValid = true
|
||||
}
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// InputPassword with a custom validator along capabilities of confirming
|
||||
// the password and reading it from disk if a specified flag is set.
|
||||
func InputPassword(
|
||||
cliCtx *cli.Context,
|
||||
passwordFileFlag *cli.StringFlag,
|
||||
promptText, confirmText string,
|
||||
shouldConfirmPassword bool,
|
||||
passwordValidator func(input string) error,
|
||||
) (string, error) {
|
||||
if cliCtx.IsSet(passwordFileFlag.Name) {
|
||||
passwordFilePathInput := cliCtx.String(passwordFileFlag.Name)
|
||||
passwordFilePath, err := file.ExpandPath(passwordFilePathInput)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "could not determine absolute path of password file")
|
||||
}
|
||||
data, err := ioutil.ReadFile(passwordFilePath) // #nosec G304
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "could not read password file")
|
||||
}
|
||||
enteredPassword := strings.TrimRight(string(data), "\r\n")
|
||||
if err := passwordValidator(enteredPassword); err != nil {
|
||||
return "", errors.Wrap(err, "password did not pass validation")
|
||||
}
|
||||
return enteredPassword, nil
|
||||
}
|
||||
if strings.Contains(strings.ToLower(promptText), "new wallet") {
|
||||
fmt.Println("Password requirements: at least 8 characters including at least 1 alphabetical character, 1 number, and 1 unicode special character. " +
|
||||
"Must not be a common password nor easy to guess")
|
||||
}
|
||||
var hasValidPassword bool
|
||||
var password string
|
||||
var err error
|
||||
for !hasValidPassword {
|
||||
password, err = PasswordPrompt(promptText, passwordValidator)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not read password: %w", err)
|
||||
}
|
||||
if shouldConfirmPassword {
|
||||
passwordConfirmation, err := PasswordPrompt(confirmText, passwordValidator)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not read password confirmation: %w", err)
|
||||
}
|
||||
if password != passwordConfirmation {
|
||||
log.Error("Passwords do not match")
|
||||
continue
|
||||
}
|
||||
hasValidPassword = true
|
||||
} else {
|
||||
return password, nil
|
||||
}
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
124
io/prompt/validate.go
Normal file
124
io/prompt/validate.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
strongPasswords "github.com/nbutton23/zxcvbn-go"
|
||||
)
|
||||
|
||||
const (
|
||||
// Constants for passwords.
|
||||
minPasswordLength = 8
|
||||
// Min password score of 2 out of 5 based on the https://github.com/nbutton23/zxcvbn-go
|
||||
// library for strong-entropy password computation.
|
||||
minPasswordScore = 2
|
||||
)
|
||||
|
||||
var (
|
||||
errIncorrectPhrase = errors.New("input does not match wanted phrase")
|
||||
errPasswordWeak = errors.New("password must have at least 8 characters, at least 1 alphabetical character, 1 unicode symbol, and 1 number")
|
||||
)
|
||||
|
||||
// NotEmpty is a validation function to make sure the input given isn't empty and is valid unicode.
|
||||
func NotEmpty(input string) error {
|
||||
if input == "" {
|
||||
return errors.New("input cannot be empty")
|
||||
}
|
||||
if !IsValidUnicode(input) {
|
||||
return errors.New("not valid unicode")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateNumber makes sure the entered text is a valid number.
|
||||
func ValidateNumber(input string) error {
|
||||
_, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConfirmation makes sure the entered text is the user confirming.
|
||||
func ValidateConfirmation(input string) error {
|
||||
if input != "Y" && input != "y" {
|
||||
return errors.New("please confirm the above text")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateYesOrNo ensures the user input either Y, y or N, n.
|
||||
func ValidateYesOrNo(input string) error {
|
||||
lowercase := strings.ToLower(input)
|
||||
if lowercase != "y" && lowercase != "n" {
|
||||
return errors.New("please enter y or n")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidUnicode checks if an input string is a valid unicode string comprised of only
|
||||
// letters, numbers, punctuation, or symbols.
|
||||
func IsValidUnicode(input string) bool {
|
||||
for _, char := range input {
|
||||
if !(unicode.IsLetter(char) ||
|
||||
unicode.IsNumber(char) ||
|
||||
unicode.IsPunct(char) ||
|
||||
unicode.IsSymbol(char) ||
|
||||
unicode.IsSpace(char)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidatePasswordInput validates a strong password input for new accounts,
|
||||
// including a min length, at least 1 number and at least
|
||||
// 1 special character.
|
||||
func ValidatePasswordInput(input string) error {
|
||||
var (
|
||||
hasMinLen = false
|
||||
hasLetter = false
|
||||
hasNumber = false
|
||||
hasSpecial = false
|
||||
)
|
||||
if len(input) >= minPasswordLength {
|
||||
hasMinLen = true
|
||||
}
|
||||
for _, char := range input {
|
||||
switch {
|
||||
case !(unicode.IsSpace(char) ||
|
||||
unicode.IsLetter(char) ||
|
||||
unicode.IsNumber(char) ||
|
||||
unicode.IsPunct(char) ||
|
||||
unicode.IsSymbol(char)):
|
||||
return errors.New("password must only contain unicode alphanumeric characters, numbers, or unicode symbols")
|
||||
case unicode.IsLetter(char):
|
||||
hasLetter = true
|
||||
case unicode.IsNumber(char):
|
||||
hasNumber = true
|
||||
case unicode.IsPunct(char) || unicode.IsSymbol(char):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
if !(hasMinLen && hasLetter && hasNumber && hasSpecial) {
|
||||
return errPasswordWeak
|
||||
}
|
||||
strength := strongPasswords.PasswordStrength(input, nil)
|
||||
if strength.Score < minPasswordScore {
|
||||
return errors.New(
|
||||
"password is too easy to guess, try a stronger password",
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePhrase checks whether the user input is equal to the wanted phrase. The verification is case sensitive.
|
||||
func ValidatePhrase(input, wantedPhrase string) error {
|
||||
if strings.TrimSpace(input) != wantedPhrase {
|
||||
return errIncorrectPhrase
|
||||
}
|
||||
return nil
|
||||
}
|
||||
190
io/prompt/validate_test.go
Normal file
190
io/prompt/validate_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
|
||||
"github.com/prysmaticlabs/prysm/shared/testutil/require"
|
||||
)
|
||||
|
||||
func TestValidatePasswordInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantedErr string
|
||||
}{
|
||||
{
|
||||
name: "no numbers nor special characters",
|
||||
input: "abcdefghijklmnopqrs",
|
||||
wantedErr: errPasswordWeak.Error(),
|
||||
},
|
||||
{
|
||||
name: "number and letters but no special characters",
|
||||
input: "abcdefghijklmnopqrs2020",
|
||||
wantedErr: errPasswordWeak.Error(),
|
||||
},
|
||||
{
|
||||
name: "numbers, letters, special characters, but too short",
|
||||
input: "abc2$",
|
||||
wantedErr: errPasswordWeak.Error(),
|
||||
},
|
||||
{
|
||||
name: "proper length and strong password",
|
||||
input: "%Str0ngpassword32kjAjsd22020$%",
|
||||
},
|
||||
{
|
||||
name: "password format correct but weak entropy score",
|
||||
input: "aaaaaaa1$",
|
||||
wantedErr: "password is too easy to guess, try a stronger password",
|
||||
},
|
||||
{
|
||||
name: "allow spaces",
|
||||
input: "x*329293@aAJSD i22903saj",
|
||||
},
|
||||
{
|
||||
name: "strong password from LastPass",
|
||||
input: "jXl!q5pkQnXsyT6dbJ3X5plQ!9%iqJCTr&*UIoaDu#b6GYJD##^GI3qniKdr240f",
|
||||
},
|
||||
{
|
||||
name: "allow underscores",
|
||||
input: "jXl!q5pkQn_syT6dbJ3X5plQ_9_iqJCTr_*UIoaDu#b6GYJD##^GI3qniKdr240f",
|
||||
},
|
||||
{
|
||||
name: "only numbers and symbols should fail",
|
||||
input: "123493489223423_23923929",
|
||||
wantedErr: errPasswordWeak.Error(),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePasswordInput(tt.input)
|
||||
if tt.wantedErr != "" {
|
||||
assert.ErrorContains(t, tt.wantedErr, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidUnicode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Regular alphanumeric",
|
||||
input: "Someone23xx",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Unicode strings separated by a space character",
|
||||
input: "x*329293@aAJSD i22903saj",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Japanese",
|
||||
input: "僕は絵お見るのが好きです",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "Other foreign",
|
||||
input: "Etérium",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsValidUnicode(tt.input); got != tt.want {
|
||||
t.Errorf("isValidUnicode() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAndValidatePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
def string
|
||||
want string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "number",
|
||||
input: "3",
|
||||
def: "0",
|
||||
want: "3",
|
||||
},
|
||||
{
|
||||
name: "empty return default",
|
||||
input: "",
|
||||
def: "0",
|
||||
want: "0",
|
||||
},
|
||||
{
|
||||
name: "empty return default no zero",
|
||||
input: "",
|
||||
def: "3",
|
||||
want: "3",
|
||||
},
|
||||
{
|
||||
name: "empty return default, no zero",
|
||||
input: "a",
|
||||
def: "0",
|
||||
want: "",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
content := []byte(tt.input + "\n")
|
||||
tmpfile, err := ioutil.TempFile("", "content")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := os.Remove(tmpfile.Name())
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
_, err = tmpfile.Write(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tmpfile.Seek(0, 0)
|
||||
require.NoError(t, err)
|
||||
oldStdin := os.Stdin
|
||||
defer func() { os.Stdin = oldStdin }() // Restore original Stdin
|
||||
os.Stdin = tmpfile
|
||||
got, err := DefaultAndValidatePrompt(tt.name, tt.def, ValidateNumber)
|
||||
if !tt.wantError {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
err = tmpfile.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePhrase(t *testing.T) {
|
||||
wantedPhrase := "wanted phrase"
|
||||
|
||||
t.Run("correct input", func(t *testing.T) {
|
||||
assert.NoError(t, ValidatePhrase(wantedPhrase, wantedPhrase))
|
||||
})
|
||||
t.Run("correct input with whitespace", func(t *testing.T) {
|
||||
assert.NoError(t, ValidatePhrase(" wanted phrase ", wantedPhrase))
|
||||
})
|
||||
t.Run("incorrect input", func(t *testing.T) {
|
||||
err := ValidatePhrase("foo", wantedPhrase)
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorContains(t, errIncorrectPhrase.Error(), err)
|
||||
})
|
||||
t.Run("wrong letter case", func(t *testing.T) {
|
||||
err := ValidatePhrase("Wanted Phrase", wantedPhrase)
|
||||
assert.NotNil(t, err)
|
||||
assert.ErrorContains(t, errIncorrectPhrase.Error(), err)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user