From d50be194c537851ec8f74814683934623b86e43e Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Tue, 2 Jun 2026 15:51:54 -0700 Subject: [PATCH] [vfio-manage] migrate to go-nvlib/pkg/nvpassthrough Signed-off-by: Christopher Desiniotis --- cmd/vfio-manage/bind.go | 27 +- cmd/vfio-manage/unbind.go | 20 +- go.mod | 7 +- go.sum | 8 +- internal/nvpassthrough/modalias.go | 215 --------------- internal/nvpassthrough/modalias_test.go | 292 --------------------- internal/nvpassthrough/nvpassthrough.go | 334 ------------------------ 7 files changed, 38 insertions(+), 865 deletions(-) delete mode 100644 internal/nvpassthrough/modalias.go delete mode 100644 internal/nvpassthrough/modalias_test.go delete mode 100644 internal/nvpassthrough/nvpassthrough.go diff --git a/cmd/vfio-manage/bind.go b/cmd/vfio-manage/bind.go index a9721da2..1274fa7a 100644 --- a/cmd/vfio-manage/bind.go +++ b/cmd/vfio-manage/bind.go @@ -25,7 +25,7 @@ import ( "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" - "github.com/NVIDIA/k8s-driver-manager/internal/nvpassthrough" + "github.com/NVIDIA/go-nvlib/pkg/nvpassthrough" ) type bindCommand struct { @@ -38,7 +38,7 @@ type bindCommand struct { type bindOptions struct { all bool deviceID string - hostRoot string + libModulesRoot string bindNVSwitches bool } @@ -75,11 +75,9 @@ func (m bindCommand) build() *cli.Command { Usage: "Specific device ID to bind (e.g., 0000:01:00.0)", }, &cli.StringFlag{ - Name: "host-root", - Destination: &m.options.hostRoot, - EnvVars: []string{"HOST_ROOT"}, - Value: "/", - Usage: "Path to the host's root filesystem. This is used when loading the vfio-pci module.", + Name: "host-root", + EnvVars: []string{"HOST_ROOT"}, + Usage: "DEPRECATED: the host root is no longer required to load the vfio-pci module, please use --lib-modules-root instead", }, &cli.BoolFlag{ Name: "bind-nvswitches", @@ -87,6 +85,13 @@ func (m bindCommand) build() *cli.Command { EnvVars: []string{"BIND_NVSWITCHES"}, Usage: "Also bind NVSwitches to vfio-pci (default: false)", }, + &cli.StringFlag{ + Name: "lib-modules-root", + Destination: &m.options.libModulesRoot, + EnvVars: []string{"LIB_MODULES_ROOT"}, + Value: "/lib/modules", + Usage: "Path to the /lib/modules. This is used when loading the vfio-pci module.", + }, }, } @@ -112,7 +117,9 @@ func (m bindCommand) run() error { m.nvpassthrough = nvpassthrough.New( nvpassthrough.WithLogger(m.logger), - nvpassthrough.WithHostRoot(m.options.hostRoot), + nvpassthrough.WithLibModulesRoot(m.options.libModulesRoot), + nvpassthrough.WithNvpciLib(m.nvpci), + nvpassthrough.WithLoadKernelModules(true), ) if m.options.deviceID != "" { @@ -138,7 +145,7 @@ func (m bindCommand) bindAll() error { for _, dev := range devices { m.logger.Infof("Binding device %s", dev.Address) - if err := m.nvpassthrough.BindToVFIODriver(dev); err != nil { + if err := m.nvpassthrough.BindToVFIODriver(dev.Address); err != nil { m.logger.Warnf("Failed to bind device %s: %v", dev.Address, err) } } @@ -169,7 +176,7 @@ func (m bindCommand) bindDevice() error { m.logger.Infof("Binding device %s", device) - if err := m.nvpassthrough.BindToVFIODriver(nvdev); err != nil { + if err := m.nvpassthrough.BindToVFIODriver(device); err != nil { return fmt.Errorf("failed to bind device %s to vfio driver: %w", device, err) } diff --git a/cmd/vfio-manage/unbind.go b/cmd/vfio-manage/unbind.go index 06fb47d2..009cb127 100644 --- a/cmd/vfio-manage/unbind.go +++ b/cmd/vfio-manage/unbind.go @@ -21,11 +21,10 @@ package main import ( "fmt" + "github.com/NVIDIA/go-nvlib/pkg/nvpassthrough" "github.com/NVIDIA/go-nvlib/pkg/nvpci" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" - - "github.com/NVIDIA/k8s-driver-manager/internal/nvpassthrough" ) type unbindCommand struct { @@ -43,13 +42,16 @@ type unbindOptions struct { // newUnbindCommand constructs an unbind command with the specified logger func newUnbindCommand(logger *logrus.Logger) *cli.Command { + nvpciLib := nvpci.New( + nvpci.WithLogger(logger), + ) + c := unbindCommand{ logger: logger, - nvpci: nvpci.New( - nvpci.WithLogger(logger), - ), + nvpci: nvpciLib, nvpassthrough: nvpassthrough.New( nvpassthrough.WithLogger(logger), + nvpassthrough.WithNvpciLib(nvpciLib), ), } return c.build() @@ -127,7 +129,7 @@ func (m unbindCommand) unbindAll() error { for _, dev := range devices { m.logger.Infof("Unbinding device %s", dev.Address) - if err := m.nvpassthrough.UnbindFromDriver(dev); err != nil { + if err := m.nvpassthrough.Unbind(dev.Address); err != nil { m.logger.Warnf("Failed to unbind device %s: %v", dev.Address, err) } } @@ -136,9 +138,7 @@ func (m unbindCommand) unbindAll() error { func (m unbindCommand) unbindDevice() error { device := m.options.deviceID - // Note: Despite its name, GetGPUByPciBusID returns any NVIDIA PCI device - // (GPU, NVSwitch, etc.) at the specified address, not just GPUs. - nvdev, err := m.nvpci.GetGPUByPciBusID(device) + nvdev, err := m.nvpci.GetNvidiaDeviceByPciBusID(device) if err != nil { return fmt.Errorf("failed to get NVIDIA device: %w", err) } @@ -157,7 +157,7 @@ func (m unbindCommand) unbindDevice() error { m.logger.Infof("Unbinding device %s", device) - if err := m.nvpassthrough.UnbindFromDriver(nvdev); err != nil { + if err := m.nvpassthrough.Unbind(device); err != nil { return fmt.Errorf("failed to unbind device %s from driver: %w", device, err) } diff --git a/go.mod b/go.mod index f838f0b0..6897625d 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/NVIDIA/go-nvlib v0.10.0 github.com/moby/sys/mount v0.3.4 github.com/sirupsen/logrus v1.9.4 - github.com/stretchr/testify v1.11.1 github.com/urfave/cli/v2 v2.27.7 golang.org/x/sys v0.45.0 k8s.io/api v0.36.1 @@ -15,6 +14,8 @@ require ( k8s.io/kubectl v0.36.1 ) +replace github.com/NVIDIA/go-nvlib => ../go-nvlib + require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/MakeNowJust/heredoc v1.0.0 // indirect @@ -36,6 +37,7 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.18.6 // indirect github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect @@ -46,10 +48,11 @@ require ( github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/peterbourgon/diskv v2.0.1+incompatible // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/pmorjan/kmod v1.1.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect + github.com/ulikunitz/xz v0.5.15 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect diff --git a/go.sum b/go.sum index 839b6ba1..a8b66b8b 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= -github.com/NVIDIA/go-nvlib v0.10.0 h1:2jbAFmvLBntIc/4iUChI9DzxyYNI92pohXU4kFuNrg0= -github.com/NVIDIA/go-nvlib v0.10.0/go.mod h1:7mzx9FSdO9fXWP9NKuZmWkCwhkEcSWQFe2tmFwtLb9c= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/chai2010/gettext-go v1.0.2 h1:1Lwwip6Q2QGsAdl/ZKPCwTe9fe0CjlUbqj5bFNSjIRk= @@ -57,6 +55,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -95,6 +95,8 @@ github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmorjan/kmod v1.1.1 h1:Vfw6bMaOg/sYSBCqJPT9TbqHHf5zK00GbaL5JQLO4r0= +github.com/pmorjan/kmod v1.1.1/go.mod h1:jR4fVosEpQ6b5U0rpxaqoShTDPvCjLIP8vEESZyvnqQ= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= @@ -120,6 +122,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= +github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= diff --git a/internal/nvpassthrough/modalias.go b/internal/nvpassthrough/modalias.go deleted file mode 100644 index 088f1f52..00000000 --- a/internal/nvpassthrough/modalias.go +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Copyright (c) NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package nvpassthrough - -import ( - "fmt" - "math" - "reflect" - "strings" - - "golang.org/x/sys/unix" -) - -const ( - vfioPciAliasPrefix string = "alias vfio_pci:" -) - -// modAlias is a decomposed version of string like this -// -// vNNNNNNNNdNNNNNNNNsvNNNNNNNNsdNNNNNNNNbcNNscNNiNN -// -// The "NNNN" are always of the length in the example -// unless replaced with a wildcard ("*") -type modAlias struct { - vendor string // v - device string // d - subvendor string // sv - subdevice string // sd - baseClass string // bc - subClass string // sc - programmingInterface string // i -} - -// vfioAlias represents an entry from the modules.alias file for a vfio driver -type vfioAlias struct { - modAlias *modAlias // The modalias pattern - driver string // The vfio driver name -} - -func parseModAliasString(input string) (*modAlias, error) { - if input == "" { - return nil, fmt.Errorf("modalias string is empty") - } - - input = strings.TrimSpace(input) - - // Trim the leading "pci:" prefix in the modalias file - split := strings.SplitN(input, ":", 2) - if len(split) != 2 { - return nil, fmt.Errorf("unexpected number of parts in modalias after trimming 'pci:' prefix: %s", input) - } - input = split[1] - - if !strings.HasPrefix(input, "v") { - return nil, fmt.Errorf("modalias must start with 'v', got: %s", input) - } - - ma := &modAlias{} - var before, after string - var found bool - after = input[1:] // cut leading 'v' - - before, after, found = strings.Cut(after, "d") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'd' in %q", input) - } - ma.vendor = before - - before, after, found = strings.Cut(after, "sv") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'sv' in %q", input) - } - ma.device = before - - before, after, found = strings.Cut(after, "sd") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'sd' in %q", input) - } - ma.subvendor = before - - before, after, found = strings.Cut(after, "bc") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'bc' in %q", input) - } - ma.subdevice = before - - before, after, found = strings.Cut(after, "sc") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'sc' in input %q", input) - } - ma.baseClass = before - - before, after, found = strings.Cut(after, "i") - if !found { - return nil, fmt.Errorf("failed to find delimiter 'i' in %q", input) - } - ma.subClass = before - ma.programmingInterface = after - - return ma, nil -} - -func getKernelVersion() (string, error) { - var uname unix.Utsname - if err := unix.Uname(&uname); err != nil { - return "", err - } - - // Convert C-style byte array to Go string - release := make([]byte, 0, len(uname.Release)) - for _, c := range uname.Release { - if c == 0 { - break - } - release = append(release, c) - } - - return string(release), nil -} - -// getVFIOAliases returns the vfio driver aliases from the input string. -// The input string is expected to be the content of a modules.alias file. -// Only lines that begin with 'alias vfio_pci:' are parsed, with the -// format being: -// -// alias vfio_pci: -func getVFIOAliases(input string) []vfioAlias { - var aliases []vfioAlias - - lines := strings.Split(input, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - - if !strings.HasPrefix(line, vfioPciAliasPrefix) { - continue - } - - split := strings.SplitN(line, " ", 3) - if len(split) != 3 { - continue - } - modAliasStr := split[1] - modAlias, err := parseModAliasString(modAliasStr) - if err != nil { - continue - } - - driver := split[2] - aliases = append(aliases, vfioAlias{ - modAlias: modAlias, - driver: driver, - }) - } - - return aliases -} - -// findBestMatch finds the best matching VFIO driver for the given modalias -// by comparing against all available vfio alias patterns. The best match -// is the one with the fewest wildcard characters. -func findBestMatch(deviceModAlias *modAlias, aliases []vfioAlias) string { - var bestDriver string - bestWildcardCount := math.MaxInt - - for _, alias := range aliases { - if matches, wildcardCount := matchModalias(deviceModAlias, alias.modAlias); matches { - if wildcardCount < bestWildcardCount { - bestDriver = alias.driver - bestWildcardCount = wildcardCount - } - } - } - - return bestDriver -} - -// matchModalias checks if a device modalias matches a pattern from modules.alias -// Returns true if it matches and the number of wildcards -func matchModalias(deviceModAlias, patternModAlias *modAlias) (bool, int) { - wildcardCount := 0 - - modAliasType := reflect.TypeOf(*deviceModAlias) - deviceModAliasValue := reflect.ValueOf(*deviceModAlias) - patternModAliasValue := reflect.ValueOf(*patternModAlias) - - // iterate over both modAlias structs, comparing each field - for i := 0; i < modAliasType.NumField(); i++ { - deviceValue := deviceModAliasValue.Field(i).String() - patternValue := patternModAliasValue.Field(i).String() - - if patternValue == "*" { - wildcardCount++ - continue - } - - if deviceValue != patternValue { - return false, wildcardCount - } - } - return true, wildcardCount -} diff --git a/internal/nvpassthrough/modalias_test.go b/internal/nvpassthrough/modalias_test.go deleted file mode 100644 index c78cf147..00000000 --- a/internal/nvpassthrough/modalias_test.go +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright (c) NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package nvpassthrough - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestParseModAliasString(t *testing.T) { - testCases := []struct { - description string - input string - expectedOutput *modAlias - expectedError bool - }{ - { - description: "empty string", - input: "", - expectedError: true, - }, - { - description: "more than one semicolon delimiter", - input: "pci:foo:bar", - expectedError: true, - }, - { - description: "all wildcards", - input: "pci:v*d*sv*sd*bc*sc*i*", - expectedOutput: &modAlias{ - vendor: "*", - device: "*", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - }, - { - description: "some wildcards", - input: "pci:v000010DEd00002941sv*sd*bc*sc*i*", - expectedOutput: &modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - }, - { - description: "no wildcards", - input: "pci:v000010DEd00002941sv000010DEsd00002046bc03sc02i00", - expectedOutput: &modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - modAlias, err := parseModAliasString(tc.input) - if tc.expectedError { - require.Error(t, err) - return - } - require.NotNil(t, modAlias) - require.EqualValues(t, *tc.expectedOutput, *modAlias) - }) - } -} - -func TestGetVFIOAliases(t *testing.T) { - testCases := []struct { - description string - input string - expectedOutput []vfioAlias - }{ - { - description: "empty string", - input: "", - expectedOutput: nil, - }, - { - description: "no vfio aliases", - input: ` -alias foo:v*d*sv*sd*bc*sc*i* bar -alias pci:v000010DEd00002941sv*sd*bc*sc*i* foo -`, - expectedOutput: nil, - }, - { - description: "vfio aliases present", - input: ` -alias foo:v*d*sv*sd*bc*sc*i* bar -alias pci:v000010DEd00002941sv*sd*bc*sc*i* foo -alias vfio_pci:v*d*sv*sd*bc*sc*i* vfio_pci -alias vfio_pci:v000010DEd00002941sv*sd*bc*sc*i* nvgrace_gpu_vfio_pci -`, - expectedOutput: []vfioAlias{ - { - driver: "vfio_pci", - modAlias: &modAlias{ - vendor: "*", - device: "*", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - }, - { - driver: "nvgrace_gpu_vfio_pci", - modAlias: &modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - vfioAliases := getVFIOAliases(tc.input) - require.EqualValues(t, tc.expectedOutput, vfioAliases) - }) - } -} - -func TestMatchModalias(t *testing.T) { - testCases := []struct { - description string - modalias modAlias - compareTo modAlias - expectedMatch bool - expectedWildcardCount int - }{ - { - description: "all wildcards", - modalias: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - compareTo: modAlias{ - vendor: "*", - device: "*", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - expectedMatch: true, - expectedWildcardCount: 7, - }, - { - description: "some wildcards, match", - modalias: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - compareTo: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - expectedMatch: true, - expectedWildcardCount: 5, - }, - { - description: "some wildcards, not a match", - modalias: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - compareTo: modAlias{ - vendor: "000010DE", - device: "00002900", - subvendor: "*", - subdevice: "*", - baseClass: "*", - subClass: "*", - programmingInterface: "*", - }, - expectedMatch: false, - }, - { - description: "no wildcards, match", - modalias: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - compareTo: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - expectedMatch: true, - expectedWildcardCount: 0, - }, - { - description: "no wildcards, not a match", - modalias: modAlias{ - vendor: "000010DE", - device: "00002941", - subvendor: "000010DE", - subdevice: "00002046", - baseClass: "03", - subClass: "02", - programmingInterface: "00", - }, - compareTo: modAlias{ - vendor: "00001111", - device: "00002222", - subvendor: "0000333", - subdevice: "00004444", - baseClass: "05", - subClass: "06", - programmingInterface: "07", - }, - expectedMatch: false, - }, - } - - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - match, wildcardCount := matchModalias(&tc.modalias, &tc.compareTo) - require.EqualValues(t, tc.expectedMatch, match) - if tc.expectedMatch { - require.EqualValues(t, tc.expectedWildcardCount, wildcardCount) - } - }) - } -} diff --git a/internal/nvpassthrough/nvpassthrough.go b/internal/nvpassthrough/nvpassthrough.go deleted file mode 100644 index 4b36ff63..00000000 --- a/internal/nvpassthrough/nvpassthrough.go +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Copyright (c) NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package nvpassthrough - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/NVIDIA/go-nvlib/pkg/nvpci" - "github.com/sirupsen/logrus" - - "github.com/NVIDIA/k8s-driver-manager/internal/linuxutils" -) - -const ( - pciRootDir = "/sys/bus/pci/" - pciDevicesRoot = pciRootDir + "devices" - pciDriversRoot = pciRootDir + "drivers" - vfioPCIDriverName = "vfio-pci" - consumerPrefix = "consumer:pci:" - libModulesRoot = "/lib/modules/" -) - -type Interface interface { - FindBestVFIOVariant(*nvpci.NvidiaPCIDevice) (string, error) - BindToVFIODriver(*nvpci.NvidiaPCIDevice) error - UnbindFromDriver(*nvpci.NvidiaPCIDevice) error -} - -type nvpassthrough struct { - logger *logrus.Logger - hostRoot string -} - -type nvidiaPCIAuxDevice struct { - Path string - Address string - Driver string -} - -func New(opts ...Option) Interface { - n := &nvpassthrough{} - for _, opt := range opts { - opt(n) - } - if n.logger == nil { - n.logger = logrus.New() - } - if n.hostRoot == "" { - n.hostRoot = "/" - } - - return n -} - -// Option defines a function for passing options to the New() call. -type Option func(*nvpassthrough) - -// WithLogger provides an Option to set the logger for the library. -func WithLogger(logger *logrus.Logger) Option { - return func(w *nvpassthrough) { - w.logger = logger - } -} - -// WithHostRoot provides an Option to set the path to the host root filesystem -func WithHostRoot(hostRoot string) Option { - return func(w *nvpassthrough) { - w.hostRoot = hostRoot - } -} - -// FindBestVFIOVariant finds the "best" match of all vfio_pci aliases for -// device in the host modules.alias file. This uses the algorithm of -// finding every modules.alias line that begins with "alias vfio_pci:", -// then picking the one that matches the device's own modalias value -// (from the file of that name in the device's sysfs directory) with the -// fewest "wildcards" (* character, meaning "match any value for this -// attribute"). -// -// (cdesiniotis) this code is inspired by: -// https://gitlab.com/libvirt/libvirt/-/commit/82e2fac297105f554f57fb589002933231b4f711 -func (n *nvpassthrough) FindBestVFIOVariant(device *nvpci.NvidiaPCIDevice) (string, error) { - modAliasPath := filepath.Join(device.Path, "modalias") - modAliasContent, err := os.ReadFile(modAliasPath) - if err != nil { - return "", fmt.Errorf("failed to read modalias file for %s: %w", device.Address, err) - } - - modAliasStr := strings.TrimSpace(string(modAliasContent)) - modAlias, err := parseModAliasString(modAliasStr) - if err != nil { - return "", fmt.Errorf("failed to parse modalias string %q for device %q: %w", modAliasStr, device.Address, err) - } - - kernelVersion, err := getKernelVersion() - if err != nil { - return "", fmt.Errorf("failed to get kernel version: %w", err) - } - - modulesAliasFilePath := filepath.Join(libModulesRoot, kernelVersion, "modules.alias") - modulesAliasContent, err := os.ReadFile(modulesAliasFilePath) - if err != nil { - return "", fmt.Errorf("failed to read file %s: %w", modulesAliasFilePath, err) - } - - // Get all vfio aliases from the modules.alias file - // (all lines starting with 'alias vfio_pci:') - vfioAliases := getVFIOAliases(string(modulesAliasContent)) - if len(vfioAliases) == 0 { - n.logger.Debugf("No vfio_pci entries found in modules.alias file, falling back to default vfio-pci driver") - return vfioPCIDriverName, nil - } - - // Find the best matching VFIO driver for this device - bestMatch := findBestMatch(modAlias, vfioAliases) - if bestMatch == "" { - n.logger.Debugf("No matching vfio driver found for device %s in modules.alias file, falling back to default vfio-pci driver", device.Address) - return vfioPCIDriverName, nil - } - - return bestMatch, nil -} - -// BindToVFIODriver binds the provided NVIDIA PCI device to the -// vfio-pci driver (or a variant VFIO driver if one is preferred). -// This function takes care of additional logic, like making sure -// the vfio-pci driver is loaded first and that an auxiliary graphics -// device also get bound to the vfio-pci driver. -func (n *nvpassthrough) BindToVFIODriver(device *nvpci.NvidiaPCIDevice) error { - vfioDriverName, err := n.FindBestVFIOVariant(device) - if err != nil { - return fmt.Errorf("failed to find best vfio variant driver: %w", err) - } - - km := linuxutils.NewKernelModules(n.logger, linuxutils.WithRoot(n.hostRoot)) - if err := km.Load(vfioDriverName); err != nil { - return fmt.Errorf("failed to load %q driver: %w", vfioDriverName, err) - } - - // (cdesiniotis) Module names in the modules.alias file will only ever contain - // underscores characters and not dashes -- this aligns with how the linux kernel - // stores module names internally. This can sometimes differ from the name of the - // directory in /sys/bus/pci/driver/ for a given module. For example, this - // contradiction exists for the standard vfio-pci module: - // - // $ file /sys/bus/pci/drivers/vfio-pci - // sys/bus/pci/drivers/vfio-pci: directory - // - // $ modinfo vfio-pci | grep ^name: - // name: vfio_pci - // - // To account for this difference, we check if the module name returned by - // findBestVFIOVariant() exists in /sys/bus/pci/drivers, and if not, we try - // again but with any underscore characters converted to dashes. - driverDir := filepath.Join(pciDriversRoot, vfioDriverName) - if _, err := os.Stat(driverDir); err != nil { - vfioDriverNameNormalized := strings.ReplaceAll(vfioDriverName, "_", "-") - driverDir = filepath.Join(pciDriversRoot, vfioDriverNameNormalized) - if _, err := os.Stat(driverDir); err != nil { - return fmt.Errorf("failed to find directory for vfio driver %s at %s, is the module loaded?", vfioDriverName, pciDriversRoot) - } - vfioDriverName = vfioDriverNameNormalized - } - - n.logger.Infof("Binding device %s to driver: %s", device.Address, vfioDriverName) - - if device.Driver != vfioDriverName { - if err := unbind(device.Address); err != nil { - return fmt.Errorf("failed to unbind device %s: %w", device.Address, err) - } - if err := bind(device.Address, vfioDriverName); err != nil { - return fmt.Errorf("failed to bind device %s to %s: %w", device.Address, vfioDriverName, err) - } - } - - // For graphics mode, bind the auxiliary device as well - auxDev, err := getGraphicsAuxDev(device) - if err != nil { - return fmt.Errorf("failed to get graphics auxiliary device for %s: %w", device.Address, err) - } - if auxDev == nil { - return nil - } - if auxDev.Driver == vfioDriverName { - return nil - } - - n.logger.Infof("Binding graphics auxiliary device %s to driver: %s", auxDev.Address, vfioDriverName) - - if err := unbind(auxDev.Address); err != nil { - return fmt.Errorf("failed to unbind graphics auxiliary device %s: %w", auxDev.Address, err) - } - if err := bind(auxDev.Address, vfioDriverName); err != nil { - return fmt.Errorf("failed to bind graphics auxiliary device %s to %s: %w", auxDev, vfioDriverName, err) - } - - return nil -} - -// UnbindFromDriver unbinds the provided NVIDIA PCI Device from -// any driver it is currently bound to. This function also ensures -// an auxiliary graphics device is also unbound. -func (n *nvpassthrough) UnbindFromDriver(device *nvpci.NvidiaPCIDevice) error { - if err := unbind(device.Address); err != nil { - return fmt.Errorf("failed to unbind device %s: %w", device.Address, err) - } - - // For graphics mode, unbind the auxiliary device as well - auxDev, err := getGraphicsAuxDev(device) - if err != nil { - return fmt.Errorf("failed to get graphics auxiliary device for %s: %w", device.Address, err) - } - if auxDev != nil { - if err := unbind(auxDev.Address); err != nil { - return fmt.Errorf("failed to unbind graphics auxiliary device %s: %w", auxDev.Address, err) - } - } - - return nil -} - -func bind(device string, driver string) error { - driverOverridePath := filepath.Join(pciDevicesRoot, device, "driver_override") - if err := os.WriteFile(driverOverridePath, []byte(driver), 0644); err != nil { - return fmt.Errorf("failed to set driver_override for %s: %w", device, err) - } - - bindPath := filepath.Join(pciDriversRoot, driver, "bind") - if err := os.WriteFile(bindPath, []byte(device), 0644); err != nil { - return fmt.Errorf("failed to bind %s to %s: %w", device, driver, err) - } - - return nil -} - -func unbind(device string) error { - driverOverridePath := filepath.Join(pciDevicesRoot, device, "driver_override") - if err := os.WriteFile(driverOverridePath, []byte("\n"), 0644); err != nil { - return fmt.Errorf("failed to clear driver_override for %s: %w", device, err) - } - - driverPath := filepath.Join(pciDevicesRoot, device, "driver") - if _, err := os.Stat(driverPath); os.IsNotExist(err) { - return nil - } - - driverLink, err := os.Readlink(driverPath) - if err != nil { - return fmt.Errorf("failed to read driver link for %s: %w", device, err) - } - driverName := filepath.Base(driverLink) - - unbindPath := filepath.Join(driverPath, "unbind") - if err := os.WriteFile(unbindPath, []byte(device), 0644); err != nil { - return fmt.Errorf("failed to unbind %s from %s: %w", device, driverName, err) - } - - return nil -} - -func getGraphicsAuxDev(device *nvpci.NvidiaPCIDevice) (*nvidiaPCIAuxDevice, error) { - if device.Class != nvpci.PCIVgaControllerClass { - return nil, nil - } - - // Look for consumer symlink - entries, err := os.ReadDir(device.Path) - if err != nil { - return nil, err - } - - for _, entry := range entries { - if strings.HasPrefix(entry.Name(), "consumer") { - // Extract aux device name from consumer:pci:XXXX:XX:XX.X format - parts := strings.Split(entry.Name(), consumerPrefix) - if len(parts) != 2 { - continue - } - - address := parts[1] - if address == "" { - continue - } - - // Check if aux device exists - path := filepath.Join(pciDevicesRoot, address) - if _, err := os.Stat(path); err != nil { - continue - } - - auxDev := &nvidiaPCIAuxDevice{ - Path: path, - Address: address, - } - - driver, err := getDriver(path) - if err != nil { - return nil, fmt.Errorf("failed to get driver for graphics auxiliary device %s: %w", address, err) - } - auxDev.Driver = driver - return auxDev, nil - } - } - - return nil, nil -} - -func getDriver(devicePath string) (string, error) { - driver, err := filepath.EvalSymlinks(filepath.Join(devicePath, "driver")) - switch { - case os.IsNotExist(err): - return "", nil - case err == nil: - return filepath.Base(driver), nil - } - return "", err -}