diff --git a/extension/src/experiments/commands/index.ts b/extension/src/experiments/commands/index.ts index c1f5e63862..282c3e96a2 100644 --- a/extension/src/experiments/commands/index.ts +++ b/extension/src/experiments/commands/index.ts @@ -14,8 +14,21 @@ import { RegisteredCommands } from '../../commands/external' export const getBranchExperimentCommand = (experiments: WorkspaceExperiments) => - (cwd: string, name: string, input: string) => - experiments.runCommand(AvailableCommands.EXP_BRANCH, cwd, name, input) + async (cwd: string, name: string, input: string) => { + const output = await experiments.runCommand( + AvailableCommands.EXP_BRANCH, + cwd, + name, + input + ) + + if (!output) { + return + } + + const repository = experiments.getRepository(cwd) + return repository.addBranch(input) + } export const getRenameExperimentCommand = (experiments: WorkspaceExperiments) => diff --git a/extension/src/experiments/index.ts b/extension/src/experiments/index.ts index 67f6a61ab9..1ee36c5b05 100644 --- a/extension/src/experiments/index.ts +++ b/extension/src/experiments/index.ts @@ -613,6 +613,11 @@ export class Experiments extends BaseRepository { return this.experiments.getAvailableBranchesToSelect() } + public addBranch(branch: string) { + this.experiments.addBranch(branch) + return this.refresh() + } + public refresh() { return this.data.update() } diff --git a/extension/src/experiments/model/index.ts b/extension/src/experiments/model/index.ts index b94088a8a6..e25b411539 100644 --- a/extension/src/experiments/model/index.ts +++ b/extension/src/experiments/model/index.ts @@ -585,6 +585,13 @@ export class ExperimentsModel extends ModelWithPersistence { return this.availableBranchesToSelect } + public addBranch(branch: string) { + const selectedBranches: string[] = this.getSelectedBranches() + const branchesWithNewBranch = [...selectedBranches, branch].sort() + + this.setSelectedBranches(branchesWithNewBranch) + } + public setStudioData( live: { baselineSha: string; name: string }[], pushed: string[] diff --git a/extension/src/test/suite/experiments/index.test.ts b/extension/src/test/suite/experiments/index.test.ts index 163f378d90..b3cf937161 100644 --- a/extension/src/test/suite/experiments/index.test.ts +++ b/extension/src/test/suite/experiments/index.test.ts @@ -534,31 +534,77 @@ suite('Experiments Test Suite', () => { }).timeout(WEBVIEW_TEST_TIMEOUT) it('should be able to handle a message to create a branch from an experiment', async () => { - const { mockMessageReceived } = - await stubWorkspaceGettersWebview(disposable) - - const mockBranch = 'mock-branch-input' - const inputEvent = getInputBoxEvent(mockBranch) + const { + mockMessageReceived, + experimentsModel, + mockUpdateExperimentsData + } = await stubWorkspaceGettersWebview(disposable) - const mockExperimentBranch = stub( - DvcExecutor.prototype, - 'expBranch' - ).resolves('undefined') + stub(Setup.prototype, 'getCliVersion').resolves('3.22.0') const mockExperimentId = 'exp-e7a67' + const mockBranch = 'mock-branch-input' + const mockExperimentBranch = stub(DvcExecutor.prototype, 'expBranch') + const mockSetSelectedBranches = stub( + experimentsModel, + 'setSelectedBranches' + ) + stub(window, 'showInputBox').resolves(mockBranch) + + const failedExperimentBranchEvent = new Promise(resolve => + mockExperimentBranch.onFirstCall().callsFake(() => { + resolve(undefined) + return Promise.resolve('') + }) + ) mockMessageReceived.fire({ payload: mockExperimentId, type: MessageFromWebviewType.CREATE_BRANCH_FROM_EXPERIMENT }) - await inputEvent + await failedExperimentBranchEvent + expect(mockExperimentBranch).to.be.calledOnce expect(mockExperimentBranch).to.be.calledWithExactly( dvcDemoPath, mockExperimentId, mockBranch ) + expect(mockSetSelectedBranches).not.to.be.called + expect(mockUpdateExperimentsData).not.to.be.called + + const selectedBranches = ['main', 'other'] + const selectedBranchesWithNewBranch = [ + 'main', + 'mock-branch-input', + 'other' + ] + mockExperimentBranch.onSecondCall().resolves('branch created') + stub(experimentsModel, 'getSelectedBranches') + .onFirstCall() + .returns(selectedBranches) + const waitForBranchesToBeSelected = new Promise(resolve => + mockSetSelectedBranches.callsFake(() => resolve(undefined)) + ) + + mockMessageReceived.fire({ + payload: mockExperimentId, + type: MessageFromWebviewType.CREATE_BRANCH_FROM_EXPERIMENT + }) + + await waitForBranchesToBeSelected + + expect(mockExperimentBranch).to.be.calledTwice + expect(mockExperimentBranch).to.be.calledWithExactly( + dvcDemoPath, + mockExperimentId, + mockBranch + ) + expect(mockSetSelectedBranches).to.be.calledWithExactly( + selectedBranchesWithNewBranch + ) + expect(mockUpdateExperimentsData).to.be.calledOnce }).timeout(WEBVIEW_TEST_TIMEOUT) it('should be able to handle a message to rename an experiment', async () => { diff --git a/extension/src/test/suite/experiments/util.ts b/extension/src/test/suite/experiments/util.ts index a330d40274..f08808a1ba 100644 --- a/extension/src/test/suite/experiments/util.ts +++ b/extension/src/test/suite/experiments/util.ts @@ -387,7 +387,8 @@ export const stubWorkspaceGettersWebview = async ( experimentsModel, messageSpy, mockMessageReceived, - webview + webview, + mockUpdateExperimentsData } = await buildExperimentsWebview({ disposer }) return { @@ -399,6 +400,7 @@ export const stubWorkspaceGettersWebview = async ( messageSpy, ...stubWorkspaceExperiments(dvcRoot, experiments), mockMessageReceived, + mockUpdateExperimentsData, webview } }