Skip to content
52 changes: 47 additions & 5 deletions disk.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@ package disko

import (
"encoding/json"
"errors"
"fmt"
"sort"
"strings"

"machinerun.io/disko/partid"
)

// ErrDiskTypeUndetermined is returned by RAIDController.GetDiskType when
// the controller layer cannot determine the disk type. Typical cases:
// the device is on the controller's sysfs tree but is not a configured
// virtual/logical drive and cannot be matched as a JBOD/passthrough
// disk, or controller queries were inconclusive (e.g. controller tool
// unavailable). Callers should fall back to generic udev-based detection.
var ErrDiskTypeUndetermined = errors.New("RAID controller could not determine disk type")

// DiskType enumerates supported disk types.
type DiskType int

Expand All @@ -24,19 +33,28 @@ const (

// TYPEFILE - A file on disk, not a block device.
TYPEFILE

// Unknown is an internal disko placeholder returned alongside a
// non-nil error by RAID-controller drivers when they cannot classify
// a device. It MUST NOT leak out of disko onto disko.Disk.Type: the
// linux system layer either consumes it via udev fallback or
// propagates the accompanying error. External callers should never
// observe this value.
Unknown
)

func (t DiskType) String() string {
return []string{"HDD", "SSD", "NVME", "FILE"}[t]
return []string{"HDD", "SSD", "NVME", "FILE", "UNKNOWN"}[t]
}

// StringToDiskType - convert a string to a disk type.
func StringToDiskType(typeStr string) DiskType {
kmap := map[string]DiskType{
"HDD": HDD,
"SSD": SSD,
"NVME": NVME,
"FILE": TYPEFILE,
"HDD": HDD,
"SSD": SSD,
"NVME": NVME,
"FILE": TYPEFILE,
"UNKNOWN": Unknown,
}
if dtype, ok := kmap[typeStr]; ok {
return dtype
Expand Down Expand Up @@ -485,6 +503,30 @@ type UdevInfo struct {
Properties map[string]string `json:"properties"`
}

// CollectSerials returns the set of udev serial-style tokens that
// identify this device. Used by RAID drivers (megaraid, smartpqi) to
// correlate a Linux device with a controller-reported physical drive
// serial number.
//
// Controllers typically expose the SCSI INQUIRY page-80 serial, which
// udev surfaces as ID_SCSI_SERIAL on SCSI-class devices and is the
// primary key. ID_SERIAL_SHORT and ID_SERIAL are WWN-derived on most
// SAS/SATA drives but cover drives that don't expose a distinct VPD
// page-80 serial. All non-empty values are returned so callers can match
// any of them against a controller record.
func (u UdevInfo) CollectSerials() map[string]struct{} {
out := map[string]struct{}{}

for _, key := range []string{"ID_SCSI_SERIAL", "ID_SERIAL_SHORT", "ID_SERIAL"} {
v := strings.TrimSpace(u.Properties[key])
if v != "" {
out[v] = struct{}{}
}
}

return out
}

// PartitionSet is a map of partition number to the partition.
type PartitionSet map[uint]Partition

Expand Down
58 changes: 58 additions & 0 deletions disk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"machinerun.io/disko"
"machinerun.io/disko/partid"
)
Expand Down Expand Up @@ -394,3 +396,59 @@ func TestMarshalProperties(t *testing.T) {
}
}
}

func TestUdevInfoCollectSerialsPrefersBoth(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SERIAL_SHORT": "short1",
"ID_SERIAL": "long_short1",
"ID_MODEL": "ignored",
},
}

got := ud.CollectSerials()
assert.Contains(t, got, "short1", "missing ID_SERIAL_SHORT token")
assert.Contains(t, got, "long_short1", "missing ID_SERIAL token")
assert.NotContains(t, got, "ignored",
"CollectSerials must not include unrelated properties")
}

func TestUdevInfoCollectSerialsIncludesSCSISerial(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SCSI_SERIAL": "TESTSN0PQI01",
"ID_SERIAL_SHORT": "deadbeefcafef001",
"ID_SERIAL": "3deadbeefcafef001",
},
}

got := ud.CollectSerials()
for _, want := range []string{
"TESTSN0PQI01",
"deadbeefcafef001",
"3deadbeefcafef001",
} {
assert.Contains(t, got, want, "CollectSerials: missing token %q", want)
}
}

// Whitespace-only values must not be returned as serial tokens; otherwise
// a controller record with a blank SerialNumber could match every device.
func TestUdevInfoCollectSerialsTrimsWhitespace(t *testing.T) {
ud := disko.UdevInfo{
Properties: map[string]string{
"ID_SCSI_SERIAL": " ",
"ID_SERIAL_SHORT": "",
"ID_SERIAL": "real-serial",
},
}

got := ud.CollectSerials()
assert.Contains(t, got, "real-serial", "missing real serial token")
assert.Len(t, got, 1, "expected only the non-empty token")
}

func TestUdevInfoCollectSerialsEmpty(t *testing.T) {
got := disko.UdevInfo{}.CollectSerials()
assert.Empty(t, got, "expected empty map for empty UdevInfo")
}
3 changes: 1 addition & 2 deletions linux/raidcontroller.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ const (

type RAIDController interface {
// Type() RAIDControllerType
GetDiskType(string) (disko.DiskType, error)
IsSysPathRAID(string) bool
GetDiskType(path string, udInfo disko.UdevInfo) (disko.DiskType, error)
DriverSysfsPath() string
}
65 changes: 65 additions & 0 deletions linux/sysfs/scsi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package sysfs

import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
)

// ReadSCSITarget returns Target from /sys/block/<kname>/device ->
// Host:Channel:Target:LUN (H:C:T:L). For SCSI-attached JBOD/passthrough disks
// behind a RAID HBA, Target matches the controller-reported drive ID
// (megaraid Drive.DID, mpi3mr PhysicalDrive.PID), which lets callers
// correlate a Linux block device with an entry in the controller's PD list.
//
// The "device" entry under a SCSI block device is a symlink into
// /sys/class/scsi_device/, e.g.:
//
// % ls -al /sys/block/sda/device
// lrwxrwxrwx 1 root root 0 Jan 31 16:11 /sys/block/sda/device -> ../../../2:0:0:0
//
// Callers must only invoke ReadSCSITarget for devices that udev reports as
// SCSI (ID_SCSI=1); virtio-blk, NVMe, ATA/SATA, etc. do not expose this
// symlink and should be filtered upstream. An empty kname is a caller bug
// and is rejected with an error; ok=false with a nil error is reserved for
// the benign missing-"device"-symlink case; a malformed
// Host:Channel:Target:LUN link target is reported as an error. sysRoot is
// injectable for tests; production passes "/sys".
func ReadSCSITarget(sysRoot, kname string) (target int, ok bool, err error) {
if kname == "" {
return 0, false, fmt.Errorf("invalid empty kname parameter")
}

link := filepath.Join(sysRoot, "block", kname, "device")

dest, err := os.Readlink(link)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return 0, false, nil
}
return 0, false, fmt.Errorf("readlink %q: %w", link, err)
}

// Parse out the SCSI device id from the symlink. e.g.
// ../../../0:3:110:0 -> 0:3:110:0
// matching an entry under /sys/class/scsi_device/.
scsiDeviceID := filepath.Base(dest)
hctlFields := strings.Split(scsiDeviceID, ":")

const requiredHCTLFields = 4
if len(hctlFields) != requiredHCTLFields {
return 0, false, fmt.Errorf(
"invalid SCSI Host:Channel:Target:LUN value %q from %q: expected %d fields, got %d",
scsiDeviceID, link, requiredHCTLFields, len(hctlFields))
}

t, cerr := strconv.Atoi(hctlFields[2])
if cerr != nil {
return 0, false, fmt.Errorf("parse SCSI target from %q: %w", scsiDeviceID, cerr)
}

return t, true, nil
}
83 changes: 83 additions & 0 deletions linux/sysfs/scsi_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package sysfs

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// makeBlockDeviceSymlink wires up a fake sysfs entry at
// <root>/block/<kname>/device pointing at
// ../../scsi_device/<host:channel:target:lun>.
func makeBlockDeviceSymlink(t *testing.T, root, kname, hctl string) {
t.Helper()

blockDir := filepath.Join(root, "block", kname)
require.NoError(t, os.MkdirAll(blockDir, 0o755), "mkdir %q", blockDir)

scsiDir := filepath.Join(root, "scsi_device", hctl)
require.NoError(t, os.MkdirAll(scsiDir, 0o755), "mkdir %q", scsiDir)

link := filepath.Join(blockDir, "device")
target := filepath.Join("..", "..", "scsi_device", hctl)
require.NoError(t, os.Symlink(target, link), "symlink")
}

func TestReadSCSITargetJBOD(t *testing.T) {
root := t.TempDir()
makeBlockDeviceSymlink(t, root, "sdb", "0:2:3:0")

target, ok, err := ReadSCSITarget(root, "sdb")
require.NoError(t, err)
require.True(t, ok, "expected ok=true for a SCSI-backed block device")
assert.Equal(t, 3, target, "target")
}

// An NVMe/virtio-style block device has no Host:Channel:Target:LUN
// "device" symlink.
// ReadSCSITarget must report ok=false (not an error) so the caller can
// fall through to generic udev detection.
func TestReadSCSITargetNoDevice(t *testing.T) {
Comment thread
andaaron marked this conversation as resolved.
root := t.TempDir()
require.NoError(t, os.MkdirAll(filepath.Join(root, "block", "nvme0n1"), 0o755))

_, ok, err := ReadSCSITarget(root, "nvme0n1")
require.NoError(t, err)
assert.False(t, ok, "expected ok=false when device symlink is absent")
}

// A "device" symlink whose last segment isn't Host:Channel:Target:LUN
// (e.g. points at a PCI node) is malformed for a SCSI block device and
// should be reported as an error so the caller can log/diagnose. Callers
// are expected to filter non-SCSI devices upstream via udev (ID_SCSI=1).
func TestReadSCSITargetNonHCTL(t *testing.T) {
root := t.TempDir()
blockDir := filepath.Join(root, "block", "vda")
require.NoError(t, os.MkdirAll(blockDir, 0o755))
other := filepath.Join(root, "devices", "virtio0")
require.NoError(t, os.MkdirAll(other, 0o755))
require.NoError(t, os.Symlink(filepath.Join("..", "..", "devices", "virtio0"),
filepath.Join(blockDir, "device")))

_, ok, err := ReadSCSITarget(root, "vda")
require.Error(t, err, "expected error for malformed Host:Channel:Target:LUN link target")
assert.False(t, ok, "expected ok=false when link target is not Host:Channel:Target:LUN")
}

func TestReadSCSITargetBadTarget(t *testing.T) {
root := t.TempDir()
makeBlockDeviceSymlink(t, root, "sdc", "0:2:notanum:0")

_, ok, err := ReadSCSITarget(root, "sdc")
require.Error(t, err, "expected parse error")
assert.False(t, ok, "expected ok=false on parse error")
}

func TestReadSCSITargetEmptyKname(t *testing.T) {
_, ok, err := ReadSCSITarget("/sys", "")
require.Error(t, err, "expected error for empty kname")
assert.False(t, ok, "expected ok=false for empty kname")
}
66 changes: 66 additions & 0 deletions linux/sysfs/sysfs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Package sysfs contains Linux-specific helpers for inspecting the sysfs
// hierarchy. The helpers are driver-agnostic so they can be shared by the
// top-level linux package and by individual RAID driver packages without
// creating an import cycle back into linux.
package sysfs

import (
"fmt"
"path/filepath"
"strings"
)

// IsSysPathRAID checks whether syspath (udevadm DEVPATH) belongs to a RAID
// controller whose PCI driver is registered at driverSysPath.
//
// syspath will look something like
// /devices/pci0000:3a/0000:3a:02.0/0000:3c:00.0/host0/target0:2:2/0:2:2:0/block/sdc
func IsSysPathRAID(syspath string, driverSysPath string) bool {
if !strings.HasPrefix(syspath, "/sys") {
syspath = "/sys" + syspath
}

if !strings.Contains(syspath, "/host") {
return false
}

fp, err := filepath.EvalSymlinks(syspath)
if err != nil {
fmt.Printf("seriously? %s\n", err)
return false
}

for _, path := range GetSysPaths(driverSysPath) {
if strings.HasPrefix(fp, path) {
return true
}
}

return false
}

// GetSysPaths returns the resolved PCI device paths for a RAID driver.
func GetSysPaths(driverSysPath string) []string {
paths := []string{}
// a raid driver has directory entries for each of the scsi hosts on that controller.
// $cd /sys/bus/pci/drivers/<driver name>
// $ for d in *; do [ -d "$d" ] || continue; echo "$d -> $( cd "$d" && pwd -P )"; done
// 0000:3c:00.0 -> /sys/devices/pci0000:3a/0000:3a:02.0/0000:3c:00.0
// module -> /sys/module/<driver module name>

// We take a hack path and consider anything with a ":" in that dir as a host path.
matches, err := filepath.Glob(driverSysPath + "/*:*")

if err != nil {
fmt.Printf("errors: %s\n", err)
return paths
}

for _, p := range matches {
if fp, err := filepath.EvalSymlinks(p); err == nil {
paths = append(paths, fp)
}
}

return paths
}
Loading
Loading