From 5851790fe8a54ab3d6a9ac7d98613a69084da920 Mon Sep 17 00:00:00 2001 From: Julie G <43496356+julieg18@users.noreply.github.com> Date: Mon, 14 Aug 2023 08:29:05 -0500 Subject: [PATCH] Save comparison multi plot image values across sessions (#4476) --- extension/src/persistence/constants.ts | 1 + extension/src/plots/model/index.ts | 29 ++++++++++- extension/src/plots/webview/contract.ts | 5 ++ extension/src/plots/webview/messages.ts | 21 ++++++++ extension/src/telemetry/constants.ts | 3 ++ .../src/test/fixtures/plotsDiff/index.ts | 1 + extension/src/test/suite/plots/index.test.ts | 36 +++++++++++++ extension/src/webview/contract.ts | 5 ++ webview/src/plots/components/App.test.tsx | 50 ++++++++++++++++++- .../cell/ComparisonTableMultiCell.tsx | 31 ++++++++++-- .../comparisonTable/comparisonTableSlice.ts | 1 + webview/src/plots/util/messages.ts | 11 ++++ .../src/stories/ComparisonTable.stories.tsx | 1 + 13 files changed, 188 insertions(+), 7 deletions(-) diff --git a/extension/src/persistence/constants.ts b/extension/src/persistence/constants.ts index f15f42c2b5..e3664bbb65 100644 --- a/extension/src/persistence/constants.ts +++ b/extension/src/persistence/constants.ts @@ -18,6 +18,7 @@ export enum PersistenceKey { PLOT_SECTION_COLLAPSED = 'plotSectionCollapsed:', PLOT_SELECTED_METRICS = 'plotSelectedMetrics:', PLOTS_SMOOTH_PLOT_VALUES = 'plotSmoothPlotValues:', + PLOTS_COMPARISON_MULTI_PLOT_VALUES = 'plotComparisonMultiPlotValues:', PLOT_TEMPLATE_ORDER = 'plotTemplateOrder:', SHOW_ONLY_CHANGED = 'columnsShowOnlyChanged:' } diff --git a/extension/src/plots/model/index.ts b/extension/src/plots/model/index.ts index 2528e6aee2..f24798460d 100644 --- a/extension/src/plots/model/index.ts +++ b/extension/src/plots/model/index.ts @@ -33,7 +33,8 @@ import { DEFAULT_NB_ITEMS_PER_ROW, PlotHeight, SmoothPlotValues, - ImagePlot + ImagePlot, + ComparisonMultiPlotValues } from '../webview/contract' import { EXPERIMENT_WORKSPACE_ID, @@ -80,6 +81,7 @@ export class PlotsModel extends ModelWithPersistence { private comparisonData: ComparisonData = {} private comparisonOrder: string[] + private comparisonMultiPlotValues: ComparisonMultiPlotValues = {} private smoothPlotValues: SmoothPlotValues = {} private revisionData: RevisionData = {} @@ -113,6 +115,10 @@ export class PlotsModel extends ModelWithPersistence { PersistenceKey.PLOTS_SMOOTH_PLOT_VALUES, {} ) + this.comparisonMultiPlotValues = this.revive( + PersistenceKey.PLOTS_COMPARISON_MULTI_PLOT_VALUES, + {} + ) this.cleanupOutdatedCustomPlotsState() this.cleanupOutdatedTrendsState() @@ -307,6 +313,26 @@ export class PlotsModel extends ModelWithPersistence { this.persist(PersistenceKey.PLOT_COMPARISON_ORDER, this.comparisonOrder) } + public setComparisonMultiPlotValue( + revision: string, + path: string, + value: number + ) { + if (!this.comparisonMultiPlotValues[revision]) { + this.comparisonMultiPlotValues[revision] = {} + } + + this.comparisonMultiPlotValues[revision][path] = value + this.persist( + PersistenceKey.PLOTS_COMPARISON_MULTI_PLOT_VALUES, + this.comparisonMultiPlotValues + ) + } + + public getComparisonMultiPlotValues() { + return this.comparisonMultiPlotValues + } + public getSelectedRevisionIds() { return this.experiments.getSelectedRevisions().map(({ id }) => id) } @@ -394,6 +420,7 @@ export class PlotsModel extends ModelWithPersistence { ...this.comparisonData, ...comparisonData } + this.revisionData = { ...this.revisionData, ...revisionData diff --git a/extension/src/plots/webview/contract.ts b/extension/src/plots/webview/contract.ts index 32c8817459..d1c6eae0d1 100644 --- a/extension/src/plots/webview/contract.ts +++ b/extension/src/plots/webview/contract.ts @@ -66,11 +66,16 @@ export type Revision = { label: string } +export type ComparisonMultiPlotValues = { + [revision: string]: { [path: string]: number } +} + export interface PlotsComparisonData { plots: ComparisonPlots width: number height: PlotHeight revisions: Revision[] + multiPlotValues: ComparisonMultiPlotValues } export type CustomPlotValues = { diff --git a/extension/src/plots/webview/messages.ts b/extension/src/plots/webview/messages.ts index 06eb8d3b50..1031123877 100644 --- a/extension/src/plots/webview/messages.ts +++ b/extension/src/plots/webview/messages.ts @@ -102,6 +102,12 @@ export class WebviewMessages { return this.selectPlotsFromWebview() case MessageFromWebviewType.SELECT_EXPERIMENTS: return this.selectExperimentsFromWebview() + case MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE: + return this.setComparisonMultiPlotValue( + message.payload.revision, + message.payload.path, + message.payload.value + ) case MessageFromWebviewType.REMOVE_CUSTOM_PLOTS: return commands.executeCommand( RegisteredCommands.PLOTS_CUSTOM_REMOVE, @@ -224,6 +230,20 @@ export class WebviewMessages { ) } + private setComparisonMultiPlotValue( + revision: string, + path: string, + value: number + ) { + this.plots.setComparisonMultiPlotValue(revision, path, value) + this.sendComparisonPlots() + sendTelemetryEvent( + EventName.VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE, + undefined, + undefined + ) + } + private setTemplateOrder(order: PlotsTemplatesReordered) { this.paths.setTemplateOrder(order) this.sendTemplatePlots() @@ -345,6 +365,7 @@ export class WebviewMessages { return { height: this.plots.getHeight(PlotsSection.COMPARISON_TABLE), + multiPlotValues: this.plots.getComparisonMultiPlotValues(), plots: comparison.map(({ path, revisions }) => { return { path, revisions: this.getRevisionsWithCorrectUrls(revisions) } }), diff --git a/extension/src/telemetry/constants.ts b/extension/src/telemetry/constants.ts index afe11d0c27..d4ed198165 100644 --- a/extension/src/telemetry/constants.ts +++ b/extension/src/telemetry/constants.ts @@ -92,6 +92,8 @@ export const EventName = Object.assign( VIEWS_PLOTS_SECTION_TOGGLE: 'views.plots.toggleSection', VIEWS_PLOTS_SELECT_EXPERIMENTS: 'view.plots.selectExperiments', VIEWS_PLOTS_SELECT_PLOTS: 'view.plots.selectPlots', + VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE: + 'view.plots.setComparisonMultiPlotValue', VIEWS_PLOTS_SET_SMOOTH_PLOT_VALUE: 'view.plots.setSmoothPlotValues', VIEWS_PLOTS_ZOOM_PLOT: 'views.plots.zoomPlot', VIEWS_REORDER_PLOTS_CUSTOM: 'views.plots.customReordered', @@ -296,6 +298,7 @@ export interface IEventNamePropertyMapping { [EventName.VIEWS_PLOTS_ZOOM_PLOT]: { isImage: boolean } [EventName.VIEWS_REORDER_PLOTS_CUSTOM]: undefined [EventName.VIEWS_REORDER_PLOTS_TEMPLATES]: undefined + [EventName.VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE]: undefined [EventName.VIEWS_PLOTS_SET_SMOOTH_PLOT_VALUE]: undefined [EventName.VIEWS_PLOTS_PATH_TREE_OPENED]: DvcRootCount diff --git a/extension/src/test/fixtures/plotsDiff/index.ts b/extension/src/test/fixtures/plotsDiff/index.ts index b54e733562..65513e0caf 100644 --- a/extension/src/test/fixtures/plotsDiff/index.ts +++ b/extension/src/test/fixtures/plotsDiff/index.ts @@ -811,6 +811,7 @@ export const getComparisonWebviewMessage = ( return { revisions: getRevisions(), + multiPlotValues: {}, plots: Object.values(plotAcc), width: DEFAULT_PLOT_WIDTH, height: DEFAULT_PLOT_HEIGHT diff --git a/extension/src/test/suite/plots/index.test.ts b/extension/src/test/suite/plots/index.test.ts index b0444d7b3a..2e1287316a 100644 --- a/extension/src/test/suite/plots/index.test.ts +++ b/extension/src/test/suite/plots/index.test.ts @@ -1065,6 +1065,42 @@ suite('Plots Test Suite', () => { ) }) + it('should handle an update comparison multi plot value message from the webview', async () => { + const { mockMessageReceived, plotsModel } = await buildPlotsWebview({ + disposer: disposable, + plotsDiff: plotsDiffFixture + }) + const multiImg = comparisonPlotsFixture.plots[3] + + const mockSendTelemetryEvent = stub(Telemetry, 'sendTelemetryEvent') + const mockSetComparisonMultiPlotValue = stub( + plotsModel, + 'setComparisonMultiPlotValue' + ) + + mockMessageReceived.fire({ + payload: { + path: multiImg.path, + revision: 'main', + value: 5 + }, + type: MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE + }) + + expect(mockSendTelemetryEvent).to.be.called + expect(mockSendTelemetryEvent).to.be.calledWithExactly( + EventName.VIEWS_PLOTS_SET_COMPARISON_MULTI_PLOT_VALUE, + undefined, + undefined + ) + expect(mockSetComparisonMultiPlotValue).to.be.called + expect(mockSetComparisonMultiPlotValue).to.be.calledWithExactly( + 'main', + multiImg.path, + 5 + ) + }) + it('should handle the CLI throwing an error', async () => { const { data, errorsModel, mockPlotsDiff, plots, plotsModel } = await buildPlots({ disposer: disposable, plotsDiff: plotsDiffFixture }) diff --git a/extension/src/webview/contract.ts b/extension/src/webview/contract.ts index 170864149e..fcc22479c4 100644 --- a/extension/src/webview/contract.ts +++ b/extension/src/webview/contract.ts @@ -44,6 +44,7 @@ export enum MessageFromWebviewType { RESIZE_COLUMN = 'resize-column', RESIZE_PLOTS = 'resize-plots', SAVE_STUDIO_TOKEN = 'save-studio-token', + SET_COMPARISON_MULTI_PLOT_VALUE = 'update-comparison-multi-plot-value', SET_SMOOTH_PLOT_VALUE = 'update-smooth-plot-value', SHOW_EXPERIMENT_LOGS = 'show-experiment-logs', SHOW_WALKTHROUGH = 'show-walkthrough', @@ -211,6 +212,10 @@ export type MessageFromWebview = type: MessageFromWebviewType.REORDER_PLOTS_COMPARISON_ROWS payload: string[] } + | { + type: MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE + payload: { path: string; revision: string; value: number } + } | { type: MessageFromWebviewType.REORDER_PLOTS_CUSTOM payload: string[] diff --git a/webview/src/plots/components/App.test.tsx b/webview/src/plots/components/App.test.tsx index c18d5adf24..6b17d13281 100644 --- a/webview/src/plots/components/App.test.tsx +++ b/webview/src/plots/components/App.test.tsx @@ -225,6 +225,7 @@ describe('App', () => { renderAppWithOptionalData({ comparison: { height: DEFAULT_PLOT_HEIGHT, + multiPlotValues: {}, plots: [ { path: 'training/plots/images/misclassified.jpg', @@ -279,6 +280,7 @@ describe('App', () => { renderAppWithOptionalData({ comparison: { height: DEFAULT_PLOT_HEIGHT, + multiPlotValues: {}, plots: [ { path: 'training/plots/images/image', @@ -1757,6 +1759,25 @@ describe('App', () => { const workspaceImgs = comparisonTableFixture.plots[3].revisions.workspace.imgs + const multiImgPlots = screen.getAllByTestId('multi-image-cell') + const slider = within(multiImgPlots[0]).getByRole('slider') + const workspaceImgEl = within(multiImgPlots[0]).getByRole('img') + + expect(workspaceImgEl).toHaveAttribute('src', workspaceImgs[0].url) + + fireEvent.change(slider, { target: { value: 3 } }) + + expect(workspaceImgEl).toHaveAttribute('src', workspaceImgs[3].url) + }) + + it('should send a message when the slider changes', async () => { + renderAppWithOptionalData({ + comparison: comparisonTableFixture + }) + + const multiImg = comparisonTableFixture.plots[3] + const workspacePlot = multiImg.revisions.workspace + const workspaceImgs = workspacePlot.imgs const multiImgPlots = screen.getAllByTestId('multi-image-cell') const slider = within(multiImgPlots[0]).getByRole('slider') @@ -1766,6 +1787,34 @@ describe('App', () => { fireEvent.change(slider, { target: { value: 3 } }) + await waitFor( + () => { + expect(mockPostMessage).toHaveBeenCalledWith({ + payload: { + path: multiImg.path, + revision: workspacePlot.id, + value: 3 + }, + type: MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE + }) + }, + { timeout: 1000 } + ) + }) + + it('should set default slider value if given a saved value', () => { + const multiImg = comparisonTableFixture.plots[3] + renderAppWithOptionalData({ + comparison: { + ...comparisonTableFixture, + multiPlotValues: { workspace: { [multiImg.path]: 3 } } + } + }) + + const workspaceImgs = multiImg.revisions.workspace.imgs + const multiImgPlots = screen.getAllByTestId('multi-image-cell') + const workspaceImgEl = within(multiImgPlots[0]).getByRole('img') + expect(workspaceImgEl).toHaveAttribute('src', workspaceImgs[3].url) }) @@ -1795,7 +1844,6 @@ describe('App', () => { }) const mainImgs = comparisonTableFixture.plots[3].revisions.main.imgs - const multiImgPlots = screen.getAllByTestId('multi-image-cell') const slider = within(multiImgPlots[1]).getByRole('slider') diff --git a/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx index 154fcec924..0ddbc3b1db 100644 --- a/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx +++ b/webview/src/plots/components/comparisonTable/cell/ComparisonTableMultiCell.tsx @@ -1,16 +1,23 @@ -import React, { useCallback, useState } from 'react' -import { useDispatch } from 'react-redux' +import React, { useEffect, useCallback, useRef, useState } from 'react' +import { useDispatch, useSelector } from 'react-redux' import { ComparisonPlot } from 'dvc/src/plots/webview/contract' import { ComparisonTableCell } from './ComparisonTableCell' import styles from '../styles.module.scss' import { changeDisabledDragIds } from '../comparisonTableSlice' +import { setComparisonMultiPlotValue } from '../../../util/messages' +import { PlotsState } from '../../../store' export const ComparisonTableMultiCell: React.FC<{ path: string plot: ComparisonPlot }> = ({ path, plot }) => { - const [currentStep, setCurrentStep] = useState(0) + const values = useSelector( + (state: PlotsState) => state.comparison.multiPlotValues + ) + const [currentStep, setCurrentStep] = useState(values?.[plot.id]?.[path] || 0) const dispatch = useDispatch() + const maxStep = plot.imgs.length - 1 + const changeDebounceTimer = useRef(0) const addDisabled = useCallback(() => { dispatch(changeDisabledDragIds([path])) @@ -20,6 +27,16 @@ export const ComparisonTableMultiCell: React.FC<{ dispatch(changeDisabledDragIds([])) }, [dispatch]) + useEffect(() => { + window.clearTimeout(changeDebounceTimer.current) + changeDebounceTimer.current = window.setTimeout(() => { + if (currentStep === values?.[plot.id]?.[path]) { + return + } + setComparisonMultiPlotValue(path, plot.id, currentStep) + }, 500) + }, [values, path, plot.id, currentStep]) + return (
{ + if (!event.target) { + return + } + setCurrentStep(Number(event.target.value)) }} /> diff --git a/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts b/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts index 81029d8710..0ddd2f385e 100644 --- a/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts +++ b/webview/src/plots/components/comparisonTable/comparisonTableSlice.ts @@ -21,6 +21,7 @@ export const comparisonTableInitialState: ComparisonTableState = { hasData: false, height: DEFAULT_HEIGHT[PlotsSection.COMPARISON_TABLE], isCollapsed: DEFAULT_SECTION_COLLAPSED[PlotsSection.COMPARISON_TABLE], + multiPlotValues: {}, plots: [], revisions: [], rowHeight: DEFAULT_ROW_HEIGHT, diff --git a/webview/src/plots/util/messages.ts b/webview/src/plots/util/messages.ts index 24639833ab..a356ad7bea 100644 --- a/webview/src/plots/util/messages.ts +++ b/webview/src/plots/util/messages.ts @@ -96,6 +96,17 @@ export const selectRevisions = () => type: MessageFromWebviewType.SELECT_EXPERIMENTS }) +export const setComparisonMultiPlotValue = ( + path: string, + revision: string, + value: number +) => { + sendMessage({ + payload: { path, revision, value }, + type: MessageFromWebviewType.SET_COMPARISON_MULTI_PLOT_VALUE + }) +} + export const togglePlotsSection = ( sectionKey: PlotsSection, sectionCollapsed: boolean diff --git a/webview/src/stories/ComparisonTable.stories.tsx b/webview/src/stories/ComparisonTable.stories.tsx index cef1732ad2..ae1650153d 100644 --- a/webview/src/stories/ComparisonTable.stories.tsx +++ b/webview/src/stories/ComparisonTable.stories.tsx @@ -47,6 +47,7 @@ const Template: StoryFn = ({ plots, revisions }) => {