Skip to content
52 changes: 47 additions & 5 deletions disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@ package disko

import (
"encoding/json"
"errors"
"fmt"
"sort"
"strings"

"machinerun.io/disko/partid"
)

// ErrDiskTypeUndetermined is returned by RAIDController.GetDiskType when
// the controller layer cannot determine the disk type. Typical cases:
// the device is on the controller's sysfs tree but is not a configured
// virtual/logical drive and cannot be matched as a JBOD/passthrough
// disk, or controller queries were inconclusive (e.g. controller tool
// unavailable). Callers should fall back to generic udev-based detection.
var ErrDiskTypeUndetermined = errors.New("RAID controller could not determine disk type")

// DiskType enumerates supported disk types.
type DiskType int

Expand All @@ -24,19 +33,28 @@ const (

// TYPEFILE - A file on disk, not a block device.
TYPEFILE

// Unknown is an internal disko placeholder returned alongside a
// non-nil error by RAID-controller drivers when they cannot classify
// a device. It MUST NOT leak out of disko onto disko.Disk.Type: the
// linux system layer either consumes it via udev fallback or
// propagates the accompanying error. External callers should never
// observe this value.
Unknown
)

func (t DiskType) String() string {
return []string{"HDD", "SSD", "NVME", "FILE"}[t]
return []string{"HDD", "SSD", "NVME", "FILE", "UNKNOWN"}[t]
}

// StringToDiskType - convert a string to a disk type.
func StringToDiskType(typeStr string) DiskType {
kmap := map[string]DiskType{
"HDD": HDD,
"SSD": SSD,
"NVME": NVME,
"FILE": TYPEFILE,
"HDD": HDD,
"SSD": SSD,
"NVME": NVME,
"FILE": TYPEFILE,
"UNKNOWN": Unknown,
}
if dtype, ok := kmap[typeStr]; ok {
return dtype
Expand Down Expand Up @@ -485,6 +503,30 @@ type UdevInfo struct {
Properties map[string]string `json:"properties"`
}

// CollectSerials returns the set of udev serial-style tokens that
// identify this device. Used by RAID drivers (megaraid, smartpqi) to
// correlate a Linux device with a controller-reported physical drive
// serial number.
//
// Controllers typically expose the SCSI INQUIRY page-80 serial, which
// udev surfaces as ID_SCSI_SERIAL on SCSI-class devices and is the
// primary key. ID_SERIAL_SHORT and ID_SERIAL are WWN-derived on most
// SAS/SATA drives but cover drives that don't expose a distinct VPD
// page-80 serial. All non-empty values are returned so callers can match
// any of them against a controller record.
func (u UdevInfo) CollectSerials() map[string]struct{} {
out := map[string]struct{}{}

for _, key := range []string{"ID_SCSI_SERIAL", "ID_SERIAL_SHORT", "ID_SERIAL"} {
v := strings.TrimSpace(u.Properties[key])
if v != "" {
out[v] = struct{}{}
}
}

return out
}

// PartitionSet is a map of partition number to the partition.
type PartitionSet map[uint]Partition

Expand Down
58 changes: 58 additions & 0 deletions disk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"machinerun.io/disko"
"machinerun.io/disko/partid"
)
Expand Down Expand Up @@ -394,3 +396,59 @@ func TestMarshalProperties(t *testing.T) {
}
}
}

func TestUdevInfoCollectSerialsPrefersBoth(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SERIAL_SHORT": "short1",
"ID_SERIAL": "long_short1",
"ID_MODEL": "ignored",
},
}

got := ud.CollectSerials()
assert.Contains(t, got, "short1", "missing ID_SERIAL_SHORT token")
assert.Contains(t, got, "long_short1", "missing ID_SERIAL token")
assert.NotContains(t, got, "ignored",
"CollectSerials must not include unrelated properties")
}

func TestUdevInfoCollectSerialsIncludesSCSISerial(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SCSI_SERIAL": "TESTSN0PQI01",
"ID_SERIAL_SHORT": "deadbeefcafef001",
"ID_SERIAL": "3deadbeefcafef001",
},
}

got := ud.CollectSerials()
for _, want := range []string{
"TESTSN0PQI01",
"deadbeefcafef001",
"3deadbeefcafef001",
} {
assert.Contains(t, got, want, "CollectSerials: missing token %q", want)
}
}

// Whitespace-only values must not be returned as serial tokens; otherwise
// a controller record with a blank SerialNumber could match every device.
func TestUdevInfoCollectSerialsTrimsWhitespace(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SCSI_SERIAL": " ",
"ID_SERIAL_SHORT": "",
"ID_SERIAL": "real-serial",
},
}

got := ud.CollectSerials()
assert.Contains(t, got, "real-serial", "missing real serial token")
assert.Len(t, got, 1, "expected only the non-empty token")
}

func TestUdevInfoCollectSerialsEmpty(t *testing.T) {
got := disko.UdevInfo{}.CollectSerials()
assert.Empty(t, got, "expected empty map for empty UdevInfo")
}
3 changes: 1 addition & 2 deletions linux/raidcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ const (

type RAIDController interface {
// Type() RAIDControllerType
GetDiskType(string) (disko.DiskType, error)
IsSysPathRAID(string) bool
GetDiskType(path string, udInfo disko.UdevInfo) (disko.DiskType, error)
DriverSysfsPath() string
}
64 changes: 64 additions & 0 deletions linux/sysfs/scsi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package sysfs

import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)

// ReadSCSITarget returns Target from /sys/block/<kname>/device ->
// Host:Channel:Target:LUN (H:C:T:L). For SCSI-attached JBOD/passthrough disks
// behind a RAID HBA, Target matches the controller-reported drive ID
// (megaraid Drive.DID, mpi3mr PhysicalDrive.PID), which lets callers
// correlate a Linux block device with an entry in the controller's PD list.
//
// The "device" entry under a SCSI block device is a symlink into
// /sys/class/scsi_device/, e.g.:
//
// % ls -al /sys/block/sda/device
// lrwxrwxrwx 1 root root 0 Jan 31 16:11 /sys/block/sda/device -> ../../../2:0:0:0
//
// Callers must only invoke ReadSCSITarget for devices that udev reports as
// SCSI (ID_SCSI=1); virtio-blk, NVMe, ATA/SATA, etc. do not expose this
// symlink and should be filtered upstream. ok=false with a nil error is
// reserved for benign cases (empty kname, missing "device" symlink); a
// malformed Host:Channel:Target:LUN link target is reported as an
// error. sysRoot is injectable for tests; production passes "/sys".
func ReadSCSITarget(sysRoot, kname string) (target int, ok bool, err error) {
if kname == "" {
return 0, false, nil

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return 0, false, nil
return 0, false, fmt.Errorf("invalid empty 'kname' parameter")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we'd call this if kname is empty

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't. I made the change.

}

link := filepath.Join(sysRoot, "block", kname, "device")

dest, err := os.Readlink(link)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return 0, false, nil
}
return 0, false, fmt.Errorf("readlink %q: %w", link, err)
}

// Parse out the SCSI device id from the symlink. e.g.
// ../../../0:3:110:0 -> 0:3:110:0
// matching an entry under /sys/class/scsi_device/.
scsiDeviceID := filepath.Base(dest)
hctlFields := strings.Split(scsiDeviceID, ":")

const requiredHCTLFields = 4
if len(hctlFields) != requiredHCTLFields {
return 0, false, fmt.Errorf(
"invalid SCSI Host:Channel:Target:LUN value %q from %q: expected %d fields, got %d",
scsiDeviceID, link, requiredHCTLFields, len(hctlFields))
}

t, cerr := strconv.Atoi(hctlFields[2])
if cerr != nil {
return 0, false, fmt.Errorf("parse SCSI target from %q: %w", scsiDeviceID, cerr)
}

return t, true, nil
}
83 changes: 83 additions & 0 deletions linux/sysfs/scsi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package sysfs

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// makeBlockDeviceSymlink wires up a fake sysfs entry at
// <root>/block/<kname>/device pointing at
// ../../scsi_device/<host:channel:target:lun>.
func makeBlockDeviceSymlink(t *testing.T, root, kname, hctl string) {
t.Helper()

blockDir := filepath.Join(root, "block", kname)
require.NoError(t, os.MkdirAll(blockDir, 0o755), "mkdir %q", blockDir)

scsiDir := filepath.Join(root, "scsi_device", hctl)
require.NoError(t, os.MkdirAll(scsiDir, 0o755), "mkdir %q", scsiDir)

link := filepath.Join(blockDir, "device")
target := filepath.Join("..", "..", "scsi_device", hctl)
require.NoError(t, os.Symlink(target, link), "symlink")
}

func TestReadSCSITargetJBOD(t *testing.T) {
root := t.TempDir()
makeBlockDeviceSymlink(t, root, "sdb", "0:2:3:0")

target, ok, err := ReadSCSITarget(root, "sdb")
require.NoError(t, err)
require.True(t, ok, "expected ok=true for a SCSI-backed block device")
assert.Equal(t, 3, target, "target")
}

// An NVMe/virtio-style block device has no Host:Channel:Target:LUN
// "device" symlink.
// ReadSCSITarget must report ok=false (not an error) so the caller can
// fall through to generic udev detection.
func TestReadSCSITargetNoDevice(t *testing.T) {
Comment thread
andaaron marked this conversation as resolved.
root := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(root, "block", "nvme0n1"), 0o755))

_, ok, err := ReadSCSITarget(root, "nvme0n1")
require.NoError(t, err)
assert.False(t, ok, "expected ok=false when device symlink is absent")
}

// A "device" symlink whose last segment isn't Host:Channel:Target:LUN
// (e.g. points at a PCI node) is malformed for a SCSI block device and
// should be reported as an error so the caller can log/diagnose. Callers
// are expected to filter non-SCSI devices upstream via udev (ID_SCSI=1).
func TestReadSCSITargetNonHCTL(t *testing.T) {
root := t.TempDir()
blockDir := filepath.Join(root, "block", "vda")
require.NoError(t, os.MkdirAll(blockDir, 0o755))
other := filepath.Join(root, "devices", "virtio0")
require.NoError(t, os.MkdirAll(other, 0o755))
require.NoError(t, os.Symlink(filepath.Join("..", "..", "devices", "virtio0"),
filepath.Join(blockDir, "device")))

_, ok, err := ReadSCSITarget(root, "vda")
require.Error(t, err, "expected error for malformed Host:Channel:Target:LUN link target")
assert.False(t, ok, "expected ok=false when link target is not Host:Channel:Target:LUN")
}

func TestReadSCSITargetBadTarget(t *testing.T) {
root := t.TempDir()
makeBlockDeviceSymlink(t, root, "sdc", "0:2:notanum:0")

_, ok, err := ReadSCSITarget(root, "sdc")
require.Error(t, err, "expected parse error")
assert.False(t, ok, "expected ok=false on parse error")
}

func TestReadSCSITargetEmptyKname(t *testing.T) {
_, ok, err := ReadSCSITarget("/sys", "")
require.NoError(t, err)
assert.False(t, ok, "expected ok=false for empty kname")
}
66 changes: 66 additions & 0 deletions linux/sysfs/sysfs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Package sysfs contains Linux-specific helpers for inspecting the sysfs
// hierarchy. The helpers are driver-agnostic so they can be shared by the
// top-level linux package and by individual RAID driver packages without
// creating an import cycle back into linux.
package sysfs

import (
"fmt"
"path/filepath"
"strings"
)

// IsSysPathRAID checks whether syspath (udevadm DEVPATH) belongs to a RAID
// controller whose PCI driver is registered at driverSysPath.
//
// syspath will look something like
// /devices/pci0000:3a/0000:3a:02.0/0000:3c:00.0/host0/target0:2:2/0:2:2:0/block/sdc
func IsSysPathRAID(syspath string, driverSysPath string) bool {
if !strings.HasPrefix(syspath, "/sys") {
syspath = "/sys" + syspath
}

if !strings.Contains(syspath, "/host") {
return false
}

fp, err := filepath.EvalSymlinks(syspath)
if err != nil {
fmt.Printf("seriously? %s\n", err)
return false
}

for _, path := range GetSysPaths(driverSysPath) {
if strings.HasPrefix(fp, path) {
return true
}
}

return false
}

// GetSysPaths returns the resolved PCI device paths for a RAID driver.
func GetSysPaths(driverSysPath string) []string {
paths := []string{}
// a raid driver has directory entries for each of the scsi hosts on that controller.
// $cd /sys/bus/pci/drivers/<driver name>
// $ for d in *; do [ -d "$d" ] || continue; echo "$d -> $( cd "$d" && pwd -P )"; done
// 0000:3c:00.0 -> /sys/devices/pci0000:3a/0000:3a:02.0/0000:3c:00.0
// module -> /sys/module/<driver module name>

// We take a hack path and consider anything with a ":" in that dir as a host path.
matches, err := filepath.Glob(driverSysPath + "/*:*")

if err != nil {
fmt.Printf("errors: %s\n", err)
return paths
}

for _, p := range matches {
if fp, err := filepath.EvalSymlinks(p); err == nil {
paths = append(paths, fp)
}
}

return paths
}
Loading
Loading