diff --git a/api/config/v1/strategy.go b/api/config/v1/strategy.go index a28d9241c..01feff82b 100644 --- a/api/config/v1/strategy.go +++ b/api/config/v1/strategy.go @@ -48,8 +48,8 @@ func (s DeviceListStrategies) Includes(strategy string) bool { return s[strategy] } -// IsCDIEnabled returns whether any of the strategies being used require CDI. -func (s DeviceListStrategies) IsCDIEnabled() bool { +// AnyCDIEnabled returns whether any of the strategies being used require CDI. +func (s DeviceListStrategies) AnyCDIEnabled() bool { for k, v := range s { if strings.HasPrefix(k, "cdi-") && v { return true @@ -57,3 +57,13 @@ func (s DeviceListStrategies) IsCDIEnabled() bool { } return false } + +// IsOnlyCDIEnabled returns whether all strategies being used require CDI. +func (s DeviceListStrategies) AllCDIEnabled() bool { + for k, v := range s { + if !strings.HasPrefix(k, "cdi-") && v { + return false + } + } + return true +} diff --git a/cmd/nvidia-device-plugin/main.go b/cmd/nvidia-device-plugin/main.go index 81e8643d1..452750d17 100644 --- a/cmd/nvidia-device-plugin/main.go +++ b/cmd/nvidia-device-plugin/main.go @@ -152,7 +152,7 @@ func validateFlags(infolib nvinfo.Interface, config *spec.Config) error { } hasNvml, _ := infolib.HasNvml() - if deviceListStrategies.IsCDIEnabled() && !hasNvml { + if deviceListStrategies.AnyCDIEnabled() && !hasNvml { return fmt.Errorf("CDI --device-list-strategy options are only supported on NVML-based systems") } diff --git a/internal/cdi/cdi.go b/internal/cdi/cdi.go index e9292ad4e..796947fc5 100644 --- a/internal/cdi/cdi.go +++ b/internal/cdi/cdi.go @@ -75,7 +75,7 @@ func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interf opt(c) } - if !c.deviceListStrategies.IsCDIEnabled() { + if !c.deviceListStrategies.AnyCDIEnabled() { return &null{}, nil } hasNVML, _ := infolib.HasNvml() @@ -87,11 +87,14 @@ func New(infolib info.Interface, nvmllib nvml.Interface, devicelib device.Interf if c.logger == nil { c.logger = logrus.StandardLogger() } + if c.deviceIDStrategy == "" { + c.deviceIDStrategy = "uuid" + } if c.driverRoot == "" { c.driverRoot = "/" } - if c.deviceIDStrategy == "" { - c.deviceIDStrategy = "uuid" + if c.devRoot == "" { + c.devRoot = c.driverRoot } if c.targetDriverRoot == "" { c.targetDriverRoot = c.driverRoot diff --git a/internal/plugin/server.go b/internal/plugin/server.go index 167f3223e..166d0bd2f 100644 --- a/internal/plugin/server.go +++ b/internal/plugin/server.go @@ -339,21 +339,28 @@ func (plugin *NvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu response := &pluginapi.ContainerAllocateResponse{ Envs: make(map[string]string), } - if plugin.deviceListStrategies.IsCDIEnabled() { + if plugin.deviceListStrategies.AnyCDIEnabled() { responseID := uuid.New().String() if err := plugin.updateResponseForCDI(response, responseID, deviceIDs...); err != nil { return nil, fmt.Errorf("failed to get allocate response for CDI: %v", err) } } + if plugin.config.Sharing.SharingStrategy() == spec.SharingStrategyMPS { + plugin.updateResponseForMPS(response) + } + + // The following modifications are only made if at least one non-CDI device + // list strategy is selected. + if plugin.deviceListStrategies.AllCDIEnabled() { + return response, nil + } + if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyEnvvar) { plugin.updateResponseForDeviceListEnvvar(response, deviceIDs...) } if plugin.deviceListStrategies.Includes(spec.DeviceListStrategyVolumeMounts) { plugin.updateResponseForDeviceMounts(response, deviceIDs...) } - if plugin.config.Sharing.SharingStrategy() == spec.SharingStrategyMPS { - plugin.updateResponseForMPS(response) - } if *plugin.config.Flags.Plugin.PassDeviceSpecs { response.Devices = append(response.Devices, plugin.apiDeviceSpecs(*plugin.config.Flags.NvidiaDevRoot, requestIds)...) }