From 9e6e1dc950f06bbd000d5b6438f39113e8902082 Mon Sep 17 00:00:00 2001 From: zeripath Date: Wed, 8 Dec 2021 19:08:16 +0000 Subject: [PATCH] Improve checkBranchName (#17901) The current implementation of checkBranchName is highly inefficient involving opening the repository, the listing all of the branch names checking them individually before then using using opened repo to get the tags. This PR avoids this by simply walking the references from show-ref instead of opening the repository (in the nogogit case). Signed-off-by: Andrew Thornton --- modules/context/repo.go | 4 +-- modules/git/repo_branch.go | 11 +++++-- modules/git/repo_branch_gogit.go | 26 ++++++++++++++- modules/git/repo_branch_nogogit.go | 50 ++++++++++++++++++++-------- modules/git/repo_branch_test.go | 8 ++--- routers/web/repo/branch.go | 4 +-- routers/web/repo/compare.go | 4 +-- routers/web/repo/issue.go | 2 +- services/repository/adopt.go | 2 +- services/repository/branch.go | 53 ++++++++++++++---------------- 10 files changed, 106 insertions(+), 58 deletions(-) diff --git a/modules/context/repo.go b/modules/context/repo.go index 159fd07d9d..b2844c04c4 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -584,7 +584,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { } ctx.Data["Tags"] = tags - brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0) + brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0) if err != nil { ctx.ServerError("GetBranches", err) return @@ -810,7 +810,7 @@ func RepoRefByType(refType RepoRefType, ignoreNotExistErr ...bool) func(*Context if len(ctx.Params("*")) == 0 { refName = ctx.Repo.Repository.DefaultBranch if !ctx.Repo.GitRepo.IsBranchExist(refName) { - brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0) + brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0) if err != nil { ctx.ServerError("GetBranches", err) return diff --git a/modules/git/repo_branch.go b/modules/git/repo_branch.go index 98b1bc8ae7..01933d7ade 100644 --- a/modules/git/repo_branch.go +++ b/modules/git/repo_branch.go @@ -95,7 +95,12 @@ func GetBranchesByPath(path string, skip, limit int) ([]*Branch, int, error) { } defer gitRepo.Close() - brs, countAll, err := gitRepo.GetBranches(skip, limit) + return gitRepo.GetBranches(skip, limit) +} + +// GetBranches returns a slice of *git.Branch +func (repo *Repository) GetBranches(skip, limit int) ([]*Branch, int, error) { + brs, countAll, err := repo.GetBranchNames(skip, limit) if err != nil { return nil, 0, err } @@ -103,9 +108,9 @@ func GetBranchesByPath(path string, skip, limit int) ([]*Branch, int, error) { branches := make([]*Branch, len(brs)) for i := range brs { branches[i] = &Branch{ - Path: path, + Path: repo.Path, Name: brs[i], - gitRepo: gitRepo, + gitRepo: repo, } } diff --git a/modules/git/repo_branch_gogit.go b/modules/git/repo_branch_gogit.go index 6bf14b3999..d159aafd6f 100644 --- a/modules/git/repo_branch_gogit.go +++ b/modules/git/repo_branch_gogit.go @@ -9,6 +9,7 @@ package git import ( + "context" "strings" "github.com/go-git/go-git/v5/plumbing" @@ -52,7 +53,7 @@ func (repo *Repository) IsBranchExist(name string) bool { // GetBranches returns branches from the repository, skipping skip initial branches and // returning at most limit branches, or all branches if limit is 0. -func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) { +func (repo *Repository) GetBranchNames(skip, limit int) ([]string, int, error) { var branchNames []string branches, err := repo.gogitRepo.Branches() @@ -79,3 +80,26 @@ func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) { return branchNames, count, nil } + +// WalkReferences walks all the references from the repository +func WalkReferences(ctx context.Context, repoPath string, walkfn func(string) error) (int, error) { + repo, err := OpenRepositoryCtx(ctx, repoPath) + if err != nil { + return 0, err + } + defer repo.Close() + + i := 0 + iter, err := repo.gogitRepo.References() + if err != nil { + return i, err + } + defer iter.Close() + + err = iter.ForEach(func(ref *plumbing.Reference) error { + err := walkfn(string(ref.Name())) + i++ + return err + }) + return i, err +} diff --git a/modules/git/repo_branch_nogogit.go b/modules/git/repo_branch_nogogit.go index 1928c7515b..55952acda4 100644 --- a/modules/git/repo_branch_nogogit.go +++ b/modules/git/repo_branch_nogogit.go @@ -61,14 +61,29 @@ func (repo *Repository) IsBranchExist(name string) bool { return repo.IsReferenceExist(BranchPrefix + name) } -// GetBranches returns branches from the repository, skipping skip initial branches and +// GetBranchNames returns branches from the repository, skipping skip initial branches and // returning at most limit branches, or all branches if limit is 0. -func (repo *Repository) GetBranches(skip, limit int) ([]string, int, error) { +func (repo *Repository) GetBranchNames(skip, limit int) ([]string, int, error) { return callShowRef(repo.Ctx, repo.Path, BranchPrefix, "--heads", skip, limit) } +// WalkReferences walks all the references from the repository +func WalkReferences(ctx context.Context, repoPath string, walkfn func(string) error) (int, error) { + return walkShowRef(ctx, repoPath, "", 0, 0, walkfn) +} + // callShowRef return refs, if limit = 0 it will not limit func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit int) (branchNames []string, countAll int, err error) { + countAll, err = walkShowRef(ctx, repoPath, arg, skip, limit, func(branchName string) error { + branchName = strings.TrimPrefix(branchName, prefix) + branchNames = append(branchNames, branchName) + + return nil + }) + return +} + +func walkShowRef(ctx context.Context, repoPath, arg string, skip, limit int, walkfn func(string) error) (countAll int, err error) { stdoutReader, stdoutWriter := io.Pipe() defer func() { _ = stdoutReader.Close() @@ -77,7 +92,11 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit go func() { stderrBuilder := &strings.Builder{} - err := NewCommandContext(ctx, "show-ref", arg).RunInDirPipeline(repoPath, stdoutWriter, stderrBuilder) + args := []string{"show-ref"} + if arg != "" { + args = append(args, arg) + } + err := NewCommandContext(ctx, args...).RunInDirPipeline(repoPath, stdoutWriter, stderrBuilder) if err != nil { if stderrBuilder.Len() == 0 { _ = stdoutWriter.Close() @@ -94,10 +113,10 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit for i < skip { _, isPrefix, err := bufReader.ReadLine() if err == io.EOF { - return branchNames, i, nil + return i, nil } if err != nil { - return nil, 0, err + return 0, err } if !isPrefix { i++ @@ -112,39 +131,42 @@ func callShowRef(ctx context.Context, repoPath, prefix, arg string, skip, limit _, err = bufReader.ReadSlice(' ') } if err == io.EOF { - return branchNames, i, nil + return i, nil } if err != nil { - return nil, 0, err + return 0, err } branchName, err := bufReader.ReadString('\n') if err == io.EOF { // This shouldn't happen... but we'll tolerate it for the sake of peace - return branchNames, i, nil + return i, nil } if err != nil { - return nil, i, err + return i, err } - branchName = strings.TrimPrefix(branchName, prefix) + if len(branchName) > 0 { branchName = branchName[:len(branchName)-1] } - branchNames = append(branchNames, branchName) + err = walkfn(branchName) + if err != nil { + return i, err + } i++ } // count all refs for limit != 0 { _, isPrefix, err := bufReader.ReadLine() if err == io.EOF { - return branchNames, i, nil + return i, nil } if err != nil { - return nil, 0, err + return 0, err } if !isPrefix { i++ } } - return branchNames, i, nil + return i, nil } diff --git a/modules/git/repo_branch_test.go b/modules/git/repo_branch_test.go index 05d5237e6a..ac5f5deea9 100644 --- a/modules/git/repo_branch_test.go +++ b/modules/git/repo_branch_test.go @@ -17,21 +17,21 @@ func TestRepository_GetBranches(t *testing.T) { assert.NoError(t, err) defer bareRepo1.Close() - branches, countAll, err := bareRepo1.GetBranches(0, 2) + branches, countAll, err := bareRepo1.GetBranchNames(0, 2) assert.NoError(t, err) assert.Len(t, branches, 2) assert.EqualValues(t, 3, countAll) assert.ElementsMatch(t, []string{"branch1", "branch2"}, branches) - branches, countAll, err = bareRepo1.GetBranches(0, 0) + branches, countAll, err = bareRepo1.GetBranchNames(0, 0) assert.NoError(t, err) assert.Len(t, branches, 3) assert.EqualValues(t, 3, countAll) assert.ElementsMatch(t, []string{"branch1", "branch2", "master"}, branches) - branches, countAll, err = bareRepo1.GetBranches(5, 1) + branches, countAll, err = bareRepo1.GetBranchNames(5, 1) assert.NoError(t, err) assert.Len(t, branches, 0) @@ -48,7 +48,7 @@ func BenchmarkRepository_GetBranches(b *testing.B) { defer bareRepo1.Close() for i := 0; i < b.N; i++ { - _, _, err := bareRepo1.GetBranches(0, 0) + _, _, err := bareRepo1.GetBranchNames(0, 0) if err != nil { b.Fatal(err) } diff --git a/routers/web/repo/branch.go b/routers/web/repo/branch.go index 05b45eba4b..9c25180596 100644 --- a/routers/web/repo/branch.go +++ b/routers/web/repo/branch.go @@ -165,14 +165,14 @@ func redirect(ctx *context.Context) { // loadBranches loads branches from the repository limited by page & pageSize. // NOTE: May write to context on error. func loadBranches(ctx *context.Context, skip, limit int) ([]*Branch, int) { - defaultBranch, err := repo_service.GetBranch(ctx.Repo.Repository, ctx.Repo.Repository.DefaultBranch) + defaultBranch, err := ctx.Repo.GitRepo.GetBranch(ctx.Repo.Repository.DefaultBranch) if err != nil { log.Error("loadBranches: get default branch: %v", err) ctx.ServerError("GetDefaultBranch", err) return nil, 0 } - rawBranches, totalNumOfBranches, err := repo_service.GetBranches(ctx.Repo.Repository, skip, limit) + rawBranches, totalNumOfBranches, err := ctx.Repo.GitRepo.GetBranches(skip, limit) if err != nil { log.Error("GetBranches: %v", err) ctx.ServerError("GetBranches", err) diff --git a/routers/web/repo/compare.go b/routers/web/repo/compare.go index 54d7e77f2d..4cd817a399 100644 --- a/routers/web/repo/compare.go +++ b/routers/web/repo/compare.go @@ -660,7 +660,7 @@ func getBranchesAndTagsForRepo(repo *models.Repository) (branches, tags []string } defer gitRepo.Close() - branches, _, err = gitRepo.GetBranches(0, 0) + branches, _, err = gitRepo.GetBranchNames(0, 0) if err != nil { return nil, nil, err } @@ -711,7 +711,7 @@ func CompareDiff(ctx *context.Context) { return } - headBranches, _, err := ci.HeadGitRepo.GetBranches(0, 0) + headBranches, _, err := ci.HeadGitRepo.GetBranchNames(0, 0) if err != nil { ctx.ServerError("GetBranches", err) return diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index f0857b18c0..398aa26cc4 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -690,7 +690,7 @@ func RetrieveRepoMetas(ctx *context.Context, repo *models.Repository, isPull boo return nil } - brs, _, err := ctx.Repo.GitRepo.GetBranches(0, 0) + brs, _, err := ctx.Repo.GitRepo.GetBranchNames(0, 0) if err != nil { ctx.ServerError("GetBranches", err) return nil diff --git a/services/repository/adopt.go b/services/repository/adopt.go index 3f4045a778..5503155ab0 100644 --- a/services/repository/adopt.go +++ b/services/repository/adopt.go @@ -142,7 +142,7 @@ func adoptRepository(ctx context.Context, repoPath string, u *user_model.User, r repo.DefaultBranch = strings.TrimPrefix(repo.DefaultBranch, git.BranchPrefix) } - branches, _, _ := gitRepo.GetBranches(0, 0) + branches, _, _ := gitRepo.GetBranchNames(0, 0) found := false hasDefault := false hasMaster := false diff --git a/services/repository/branch.go b/services/repository/branch.go index f33bac7621..08310134bd 100644 --- a/services/repository/branch.go +++ b/services/repository/branch.go @@ -5,8 +5,10 @@ package repository import ( + "context" "errors" "fmt" + "strings" "code.gitea.io/gitea/models" user_model "code.gitea.io/gitea/models/user" @@ -20,7 +22,7 @@ import ( // CreateNewBranch creates a new repository branch func CreateNewBranch(doer *user_model.User, repo *models.Repository, oldBranchName, branchName string) (err error) { // Check if branch name can be used - if err := checkBranchName(repo, branchName); err != nil { + if err := checkBranchName(git.DefaultContext, repo, branchName); err != nil { return err } @@ -65,44 +67,39 @@ func GetBranches(repo *models.Repository, skip, limit int) ([]*git.Branch, int, } // checkBranchName validates branch name with existing repository branches -func checkBranchName(repo *models.Repository, name string) error { - gitRepo, err := git.OpenRepository(repo.RepoPath()) - if err != nil { - return err - } - defer gitRepo.Close() - - branches, _, err := GetBranches(repo, 0, 0) - if err != nil { - return err - } - - for _, branch := range branches { - if branch.Name == name { +func checkBranchName(ctx context.Context, repo *models.Repository, name string) error { + _, err := git.WalkReferences(ctx, repo.RepoPath(), func(refName string) error { + branchRefName := strings.TrimPrefix(refName, git.BranchPrefix) + switch { + case branchRefName == name: return models.ErrBranchAlreadyExists{ - BranchName: branch.Name, + BranchName: name, } - } else if (len(branch.Name) < len(name) && branch.Name+"/" == name[0:len(branch.Name)+1]) || - (len(branch.Name) > len(name) && name+"/" == branch.Name[0:len(name)+1]) { + // If branchRefName like a/b but we want to create a branch named a then we have a conflict + case strings.HasPrefix(branchRefName, name+"/"): return models.ErrBranchNameConflict{ - BranchName: branch.Name, + BranchName: branchRefName, + } + // Conversely if branchRefName like a but we want to create a branch named a/b then we also have a conflict + case strings.HasPrefix(name, branchRefName+"/"): + return models.ErrBranchNameConflict{ + BranchName: branchRefName, + } + case refName == git.TagPrefix+name: + return models.ErrTagAlreadyExists{ + TagName: name, } } - } + return nil + }) - if _, err := gitRepo.GetTag(name); err == nil { - return models.ErrTagAlreadyExists{ - TagName: name, - } - } - - return nil + return err } // CreateNewBranchFromCommit creates a new repository branch func CreateNewBranchFromCommit(doer *user_model.User, repo *models.Repository, commit, branchName string) (err error) { // Check if branch name can be used - if err := checkBranchName(repo, branchName); err != nil { + if err := checkBranchName(git.DefaultContext, repo, branchName); err != nil { return err }