Skip to content

Commit

Permalink
Skip container updates if only CDI is selected
Browse files Browse the repository at this point in the history
In the case where ONLY CDI-based device-list-strategies are
requested, the standard updates to the container allocate response
do not make sense.

This change skips these updates in this case.

Signed-off-by: Evan Lezar <[email protected]>
  • Loading branch information
elezar committed Jun 5, 2024
1 parent b8019c1 commit 23eddda
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 10 deletions.
14 changes: 12 additions & 2 deletions api/config/v1/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,22 @@ func (s DeviceListStrategies) Includes(strategy string) bool {
return s[strategy]
}

// IsCDIEnabled returns whether any of the strategies being used require CDI.
func (s DeviceListStrategies) IsCDIEnabled() bool {
// AnyCDIEnabled returns whether any of the strategies being used require CDI.
func (s DeviceListStrategies) AnyCDIEnabled() bool {
for k, v := range s {
if strings.HasPrefix(k, "cdi-") && v {
return true
}
}
return false
}

// IsOnlyCDIEnabled returns whether all strategies being used require CDI.
func (s DeviceListStrategies) AllCDIEnabled() bool {
for k, v := range s {
if !strings.HasPrefix(k, "cdi-") && v {
return false
}
}
return true
}
2 changes: 1 addition & 1 deletion cmd/nvidia-device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func validateFlags(infolib nvinfo.Interface, config *spec.Config) error {
}

hasNvml, _ := infolib.HasNvml()
if deviceListStrategies.IsCDIEnabled() && !hasNvml {
if deviceListStrategies.AnyCDIEnabled() && !hasNvml {
return fmt.Errorf("CDI --device-list-strategy options are only supported on NVML-based systems")
}

Expand Down
9 changes: 6 additions & 3 deletions internal/cdi/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interf
opt(c)
}

if !c.deviceListStrategies.IsCDIEnabled() {
if !c.deviceListStrategies.AnyCDIEnabled() {
return &null{}, nil
}
hasNVML, _ := infolib.HasNvml()
Expand All @@ -87,11 +87,14 @@ func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interf
if c.logger == nil {
c.logger = logrus.StandardLogger()
}
if c.deviceIDStrategy == "" {
c.deviceIDStrategy = "uuid"
}
if c.driverRoot == "" {
c.driverRoot = "/"
}
if c.deviceIDStrategy == "" {
c.deviceIDStrategy = "uuid"
if c.devRoot == "" {
c.devRoot = c.driverRoot
}
if c.targetDriverRoot == "" {
c.targetDriverRoot = c.driverRoot
Expand Down
15 changes: 11 additions & 4 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,21 +339,28 @@ func (plugin *NvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
response := &pluginapi.ContainerAllocateResponse{
Envs: make(map[string]string),
}
if plugin.deviceListStrategies.IsCDIEnabled() {
if plugin.deviceListStrategies.AnyCDIEnabled() {
responseID := uuid.New().String()
if err := plugin.updateResponseForCDI(response, responseID, deviceIDs...); err != nil {
return nil, fmt.Errorf("failed to get allocate response for CDI: %v", err)
}
}
if plugin.config.Sharing.SharingStrategy() == spec.SharingStrategyMPS {
plugin.updateResponseForMPS(response)
}

// The following modifications are only made if at least one non-CDI device
// list strategy is selected.
if plugin.deviceListStrategies.AllCDIEnabled() {
return response, nil
}

if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyEnvvar) {
plugin.updateResponseForDeviceListEnvvar(response, deviceIDs...)
}
if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyVolumeMounts) {
plugin.updateResponseForDeviceMounts(response, deviceIDs...)
}
if plugin.config.Sharing.SharingStrategy() == spec.SharingStrategyMPS {
plugin.updateResponseForMPS(response)
}
if *plugin.config.Flags.Plugin.PassDeviceSpecs {
response.Devices = append(response.Devices, plugin.apiDeviceSpecs(*plugin.config.Flags.NvidiaDevRoot, requestIds)...)
}
Expand Down

0 comments on commit 23eddda

Please sign in to comment.