diff --git a/components/arm/fake/fake.go b/components/arm/fake/fake.go index 70c22f14021..f25ce04493f 100644 --- a/components/arm/fake/fake.go +++ b/components/arm/fake/fake.go @@ -13,6 +13,7 @@ import ( models3d "go.viam.com/rdk/components/arm/fake/3d_models" "go.viam.com/rdk/logging" "go.viam.com/rdk/motionplan" + "go.viam.com/rdk/motionplan/armplanning" "go.viam.com/rdk/referenceframe" "go.viam.com/rdk/resource" "go.viam.com/rdk/spatialmath" @@ -255,6 +256,39 @@ func (a *Arm) GoToInputs(ctx context.Context, inputSteps ...[]referenceframe.Inp return a.MoveThroughJointPositions(ctx, inputSteps, nil, nil) } +// DoCommand handles traj-gen do_command keys. It reports capability support and executes +// a precomputed trajectory by moving through each configuration in order. +func (a *Arm) DoCommand(ctx context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + resp := map[string]interface{}{} + + if _, ok := cmd[armplanning.DoCommandKeySupportsExecuteTrajGenPlan]; ok { + resp[armplanning.DoCommandKeySupportsExecuteTrajGenPlan] = true + } + + if payload, ok := cmd[armplanning.DoCommandKeyExecuteTrajGenPlan]; ok { + payloadMap, ok := payload.(map[string]any) + if !ok { + return nil, errors.Errorf("execute_traj_gen_plan payload must be a map, got %T", payload) + } + configsRaw, ok := payloadMap["configurations_rads"] + if !ok { + return nil, errors.New("execute_traj_gen_plan payload missing configurations_rads") + } + configs, ok := configsRaw.([][]float64) + if !ok { + return nil, errors.Errorf("configurations_rads has unexpected type %T", configsRaw) + } + positions := make([][]referenceframe.Input, len(configs)) + copy(positions, configs) + if err := a.MoveThroughJointPositions(ctx, positions, nil, nil); err != nil { + return nil, err + } + resp[armplanning.DoCommandKeyExecuteTrajGenPlan] = true + } + + return resp, nil +} + // Close does nothing. func (a *Arm) Close(ctx context.Context) error { a.mu.Lock() diff --git a/motionplan/armplanning/api.go b/motionplan/armplanning/api.go index c39ac81fd42..a1d069cc708 100644 --- a/motionplan/armplanning/api.go +++ b/motionplan/armplanning/api.go @@ -183,19 +183,17 @@ type PlanMeta struct { } // PlanMotion plans a motion from a provided plan request. -func PlanMotion(ctx context.Context, parentLogger logging.Logger, request *PlanRequest) (motionplan.Plan, *PlanMeta, error) { - logger := parentLogger.Sublogger("mp") - - start := time.Now() - meta := &PlanMeta{} - ctx, span := trace.StartSpan(ctx, "PlanMotion") - defer func() { - meta.Duration = time.Since(start) - span.End() - }() - +// planWaypoints runs validation, sets up the planner, and executes multi-waypoint planning. +// It populates meta.GoalsProcessed (and meta.Partial on a partial-plan error). +// The span/timing setup is left to the caller so that it covers the caller's full lifetime. +func planWaypoints( + ctx context.Context, + logger logging.Logger, + request *PlanRequest, + meta *PlanMeta, +) ([]*referenceframe.LinearInputs, error) { if err := request.validatePlanRequest(); err != nil { - return nil, meta, err + return nil, err } logger.CDebugf(ctx, "constraint specs for this step: %v", request.Constraints) logger.CDebugf(ctx, "motion config for this step: %v", request.PlannerOptions) @@ -209,12 +207,12 @@ func PlanMotion(ctx context.Context, parentLogger logging.Logger, request *PlanR // goal configurations. However, the blocker here is the lack of a "known good" configuration used to determine which obstacles // are allowed to collide with one another. if request.StartState.structuredConfiguration == nil { - return nil, meta, errors.New("must populate start state configuration") + return nil, errors.New("must populate start state configuration") } sfPlanner, err := newPlanManager(ctx, logger, request, meta) if err != nil { - return nil, meta, err + return nil, err } trajAsInps, goalsProcessed, err := sfPlanner.planMultiWaypoint(ctx) @@ -224,11 +222,30 @@ func PlanMotion(ctx context.Context, parentLogger logging.Logger, request *PlanR meta.PartialError = err logger.Infof("returning partial plan, error: %v", err) } else { - return nil, meta, err + return nil, err } } meta.GoalsProcessed = goalsProcessed + return trajAsInps, nil +} + +// PlanMotion plans a motion from a provided plan request. +func PlanMotion(ctx context.Context, parentLogger logging.Logger, request *PlanRequest) (motionplan.Plan, *PlanMeta, error) { + logger := parentLogger.Sublogger("mp") + + start := time.Now() + meta := &PlanMeta{} + ctx, span := trace.StartSpan(ctx, "PlanMotion") + defer func() { + meta.Duration = time.Since(start) + span.End() + }() + + trajAsInps, err := planWaypoints(ctx, logger, request, meta) + if err != nil { + return nil, meta, err + } t, err := motionplan.NewSimplePlanFromTrajectory(trajAsInps, request.FrameSystem) if err != nil { diff --git a/motionplan/armplanning/check.go b/motionplan/armplanning/check.go new file mode 100644 index 00000000000..32f1ba57bff --- /dev/null +++ b/motionplan/armplanning/check.go @@ -0,0 +1,202 @@ +package armplanning + +import ( + "context" + "fmt" + "strings" + + "go.viam.com/rdk/logging" + "go.viam.com/rdk/motionplan" + "go.viam.com/rdk/referenceframe" + "go.viam.com/rdk/spatialmath" +) + +// CheckPlanFromRequest checks a plan for collisions by interpolating between each pair of +// consecutive trajectory steps. It is a convenience wrapper around CheckPlan that extracts the +// necessary information from a PlanRequest. +// +// Moving frames are auto-detected by analysing which frames have changing inputs across the +// trajectory. Collisions that exist at the start of the trajectory are automatically allowed +// throughout the plan so that the arm is not penalised for pre-existing contact. +// +// Returns nil if no collision is found, or an error describing the first detected collision. +func CheckPlanFromRequest( + ctx context.Context, + logger logging.Logger, + req *PlanRequest, + plan motionplan.Plan, +) error { + ws := req.WorldState + if ws == nil { + ws = referenceframe.NewEmptyWorldState() + } + return CheckPlan(ctx, logger, ws, req.FrameSystem, plan) +} + +// CheckPlan checks a plan for collisions by interpolating between each pair of consecutive +// trajectory steps. +// +// Moving frames are auto-detected by analysing which frames have changing inputs across the +// trajectory. Collisions that exist at the start of the trajectory are automatically allowed +// throughout the plan. +// +// Returns nil if no collision is found, or an error describing the first detected collision. +func CheckPlan( + ctx context.Context, + logger logging.Logger, + worldState *referenceframe.WorldState, + fs *referenceframe.FrameSystem, + plan motionplan.Plan, +) error { + traj := plan.Trajectory() + if len(traj) < 2 { + return nil + } + + // Convert to LinearInputs once for reuse. + linearTraj := make([]*referenceframe.LinearInputs, len(traj)) + for i, step := range traj { + linearTraj[i] = step.ToLinearInputs() + } + startInputs := linearTraj[0] + + // Get frame geometries at the start configuration. + frameSystemGeometries, err := referenceframe.FrameSystemGeometriesLinearInputs(fs, startInputs) + if err != nil { + return err + } + + // Auto-detect which frames are actually moving in this trajectory. + movingFrames := detectMovingFrames(traj) + + // Split geometries into moving and static based on the detected moving frames. + var movingGeos, staticGeos []spatialmath.Geometry + for _, geoms := range frameSystemGeometries { + for _, geom := range geoms.Geometries() { + if belongsToMovingFrame(geom.Label(), movingFrames) { + movingGeos = append(movingGeos, geom) + } else { + staticGeos = append(staticGeos, geom) + } + } + } + + // Get world obstacle geometries at the start configuration. + obstaclesInFrame, err := worldState.ObstaclesInWorldFrame(fs, startInputs.ToFrameSystemInputs()) + if err != nil { + return err + } + worldGeos := obstaclesInFrame.Geometries() + + collisionBufferMM := NewBasicPlannerOptions().CollisionBufferMM + + // Determine which collisions already exist at the start so we can ignore them. + var allowedCollisions []motionplan.Collision + if len(movingGeos) > 0 { + if len(worldGeos) > 0 { + cols, _, err := motionplan.CheckCollisions(movingGeos, worldGeos, nil, collisionBufferMM, true) + if err != nil { + return err + } + allowedCollisions = append(allowedCollisions, cols...) + } + if len(staticGeos) > 0 { + cols, _, err := motionplan.CheckCollisions(movingGeos, staticGeos, nil, collisionBufferMM, true) + if err != nil { + return err + } + allowedCollisions = append(allowedCollisions, cols...) + } + if len(movingGeos) > 1 { + cols, _, err := motionplan.CheckCollisions(movingGeos, movingGeos, nil, collisionBufferMM, true) + if err != nil { + return err + } + allowedCollisions = append(allowedCollisions, cols...) + } + } + + collisionConstraints, err := motionplan.CreateAllCollisionConstraints( + fs, movingGeos, staticGeos, worldGeos, allowedCollisions, collisionBufferMM, + ) + if err != nil { + return err + } + + checker := motionplan.NewEmptyConstraintChecker(logger) + checker.SetCollisionConstraints(collisionConstraints) + + resolution := defaultResolution + for i := 0; i < len(linearTraj)-1; i++ { + seg := &motionplan.SegmentFS{ + StartConfiguration: linearTraj[i], + EndConfiguration: linearTraj[i+1], + FS: fs, + } + if _, err := checker.CheckStateConstraintsAcrossSegmentFS(ctx, seg, resolution, true); err != nil { + return fmt.Errorf("collision in segment %d (waypoints %d→%d): %w", i, i, i+1, err) + } + } + + return nil +} + +// detectMovingFrames returns the set of frame names whose inputs change at any point in the +// trajectory. If no frames change (e.g. all waypoints are identical) every frame in the +// trajectory is treated as moving. +// +// Runs in a single pass, comparing each step to the previous one. +func detectMovingFrames(trajectory motionplan.Trajectory) map[string]bool { + movingFrames := make(map[string]bool) + framesInTraj := make(map[string]bool) + + var prev referenceframe.FrameSystemInputs + first := true + for _, step := range trajectory { + for name := range step { + framesInTraj[name] = true + } + if first { + prev = step + first = false + continue + } + for name, inputs := range step { + if movingFrames[name] { + continue + } + prevInputs := prev[name] + if len(inputs) != len(prevInputs) { + movingFrames[name] = true + continue + } + for j := range inputs { + if inputs[j] != prevInputs[j] { + movingFrames[name] = true + break + } + } + } + prev = step + } + + // Fallback: if nothing moved, treat all frames in the trajectory as moving. + if len(movingFrames) == 0 { + for name := range framesInTraj { + movingFrames[name] = true + } + } + + return movingFrames +} + +// belongsToMovingFrame returns true when the geometry label starts with one of the moving frame +// names followed by a colon (e.g. "myArm:link1" belongs to frame "myArm"). +func belongsToMovingFrame(label string, movingFrames map[string]bool) bool { + for name := range movingFrames { + if strings.HasPrefix(label, name+":") { + return true + } + } + return false +} diff --git a/motionplan/armplanning/check_test.go b/motionplan/armplanning/check_test.go new file mode 100644 index 00000000000..a1b378ad2ec --- /dev/null +++ b/motionplan/armplanning/check_test.go @@ -0,0 +1,199 @@ +package armplanning + +import ( + "context" + "testing" + + "github.com/golang/geo/r3" + "go.viam.com/test" + + "go.viam.com/rdk/logging" + "go.viam.com/rdk/motionplan" + frame "go.viam.com/rdk/referenceframe" + "go.viam.com/rdk/spatialmath" + "go.viam.com/rdk/utils" +) + +// testPlan is a minimal motionplan.Plan implementation used in tests. +type testPlan struct { + trajectory motionplan.Trajectory +} + +func (p *testPlan) Trajectory() motionplan.Trajectory { return p.trajectory } +func (p *testPlan) Path() motionplan.Path { return nil } + +func TestCheckPlan(t *testing.T) { + logger := logging.NewTestLogger(t) + ctx := context.Background() + + ur20, err := frame.ParseModelJSONFile(utils.ResolveFile("components/arm/fake/kinematics/ur20.json"), "") + test.That(t, err, test.ShouldBeNil) + + fs := frame.NewEmptyFrameSystem("test") + err = fs.AddFrame(ur20, fs.World()) + test.That(t, err, test.ShouldBeNil) + + startInputs := []frame.Input{ + 0.7853981633974483, + -0.7853981633974483, + 1.5707963267948966, + -0.7853981633974483, + 0.7853981633974483, + 0, + } + + bigWall, err := spatialmath.NewBox( + spatialmath.NewPose( + r3.Vector{X: 499.80892449234604, Y: 0, Z: 0}, + &spatialmath.OrientationVectorDegrees{OZ: 1, Theta: 0}, + ), + r3.Vector{X: 100, Y: 6774.340100002068, Z: 4708.262746117678}, + "bigWall", + ) + test.That(t, err, test.ShouldBeNil) + + littleWall1, err := spatialmath.NewBox( + spatialmath.NewPose( + r3.Vector{X: -489.0617925579456, Y: 0, Z: 0}, + &spatialmath.OrientationVectorDegrees{OZ: 1, Theta: 0}, + ), + r3.Vector{X: 693.3530661392058, Y: 100, Z: 725.3808831665151}, + "littleWall1", + ) + test.That(t, err, test.ShouldBeNil) + + littleWall2, err := spatialmath.NewBox( + spatialmath.NewPose( + r3.Vector{X: -812.5564475789858, Y: 0, Z: 295.11694017940315}, + &spatialmath.OrientationVectorDegrees{OZ: 1, Theta: 0}, + ), + r3.Vector{X: 369.641316443868, Y: 100, Z: 596.3239775519398}, + "littleWall2", + ) + test.That(t, err, test.ShouldBeNil) + + worldState, err := frame.NewWorldState( + []*frame.GeometriesInFrame{ + frame.NewGeometriesInFrame(frame.World, []spatialmath.Geometry{bigWall, littleWall1, littleWall2}), + }, + nil, + ) + test.That(t, err, test.ShouldBeNil) + + goalPose := spatialmath.NewPose( + r3.Vector{ + X: -1091.7784630090632, + Y: 653.2215369909372, + Z: 171.2573338461849, + }, + &spatialmath.OrientationVectorDegrees{ + OX: -0.9999999999999999, + OY: -5.551115123125783e-17, + OZ: -8.495620873461007e-11, + Theta: 89.9999999833808, + }, + ) + + planRequest := &PlanRequest{ + FrameSystem: fs, + Goals: []*PlanState{ + {poses: frame.FrameSystemPoses{ur20.Name(): frame.NewPoseInFrame(frame.World, goalPose)}}, + }, + StartState: &PlanState{structuredConfiguration: frame.FrameSystemInputs{ur20.Name(): startInputs}}, + WorldState: worldState, + PlannerOptions: NewBasicPlannerOptions(), + } + + plan, _, err := PlanMotion(ctx, logger, planRequest) + test.That(t, err, test.ShouldBeNil) + test.That(t, plan, test.ShouldNotBeNil) + + err = CheckPlan(ctx, logger, worldState, fs, plan) + test.That(t, err, test.ShouldBeNil) + + err = CheckPlanFromRequest(ctx, logger, planRequest, plan) + test.That(t, err, test.ShouldBeNil) +} + +// TestCheckPlanWithAllowedCollisions verifies that collisions present at the start configuration +// are allowed throughout the plan and do not cause CheckPlan to fail. +func TestCheckPlanWithAllowedCollisions(t *testing.T) { + logger := logging.NewTestLogger(t) + ctx := context.Background() + + ur5, err := frame.ParseModelJSONFile(utils.ResolveFile("components/arm/fake/kinematics/ur5e.json"), "") + test.That(t, err, test.ShouldBeNil) + + fs := frame.NewEmptyFrameSystem("test") + err = fs.AddFrame(ur5, fs.World()) + test.That(t, err, test.ShouldBeNil) + + startInputs := []frame.Input{0, -1.5708, 1.5708, 0, 0, 0} + + // Find where the forearm link is so we can place an obstacle on top of it. + startLinearInputs := frame.FrameSystemInputs{ur5.Name(): startInputs}.ToLinearInputs() + fsGeoms, err := frame.FrameSystemGeometriesLinearInputs(fs, startLinearInputs) + test.That(t, err, test.ShouldBeNil) + + var forearmCenter r3.Vector + for _, geomsInFrame := range fsGeoms { + for _, geom := range geomsInFrame.Geometries() { + if geom.Label() == "UR5e:forearm_link" { + forearmCenter = geom.Pose().Point() + break + } + } + } + + // Create an obstacle overlapping the forearm link at the start configuration. + obstacle, err := spatialmath.NewBox( + spatialmath.NewPose(forearmCenter, &spatialmath.OrientationVectorDegrees{OZ: 1, Theta: 0}), + r3.Vector{X: 300, Y: 300, Z: 300}, + "obstacle", + ) + test.That(t, err, test.ShouldBeNil) + + worldState, err := frame.NewWorldState( + []*frame.GeometriesInFrame{ + frame.NewGeometriesInFrame(frame.World, []spatialmath.Geometry{obstacle}), + }, + nil, + ) + test.That(t, err, test.ShouldBeNil) + + // A trajectory that doesn't move — the start-state collision should be allowed throughout. + plan := &testPlan{ + trajectory: motionplan.Trajectory{ + {ur5.Name(): startInputs}, + {ur5.Name(): startInputs}, + }, + } + + err = CheckPlan(ctx, logger, worldState, fs, plan) + test.That(t, err, test.ShouldBeNil) +} + +// TestCheckPlanNilWorldState verifies that a nil WorldState is handled gracefully (treated as +// having no obstacles) rather than panicking. +func TestCheckPlanNilWorldState(t *testing.T) { + logger := logging.NewTestLogger(t) + ctx := context.Background() + + ur5, err := frame.ParseModelJSONFile(utils.ResolveFile("components/arm/fake/kinematics/ur5e.json"), "") + test.That(t, err, test.ShouldBeNil) + + fs := frame.NewEmptyFrameSystem("test") + err = fs.AddFrame(ur5, fs.World()) + test.That(t, err, test.ShouldBeNil) + + plan := &testPlan{ + trajectory: motionplan.Trajectory{ + {ur5.Name(): []frame.Input{0, 0, 0, 0, 0, 0}}, + {ur5.Name(): []frame.Input{0.5, 0, 0, 0, 0, 0}}, + }, + } + + req := &PlanRequest{FrameSystem: fs, WorldState: nil} + err = CheckPlanFromRequest(ctx, logger, req, plan) + test.That(t, err, test.ShouldBeNil) +} diff --git a/motionplan/armplanning/trajgen.go b/motionplan/armplanning/trajgen.go new file mode 100644 index 00000000000..4ea291bab30 --- /dev/null +++ b/motionplan/armplanning/trajgen.go @@ -0,0 +1,383 @@ +package armplanning + +import ( + "context" + "fmt" + "time" + + "github.com/pkg/errors" + "go.viam.com/utils/trace" + "gorgonia.org/tensor" + + "go.viam.com/rdk/logging" + "go.viam.com/rdk/ml" + "go.viam.com/rdk/motionplan" + "go.viam.com/rdk/referenceframe" + "go.viam.com/rdk/resource" + "go.viam.com/rdk/services/mlmodel" +) + +// TrajGenConfig holds configuration for the trajectory generator ML model service. +type TrajGenConfig struct { + Service string `json:"service"` + PathToleranceDeltaRads *float64 `json:"path_tolerance_delta_rads,omitempty"` + PathColinearizationRatio *float64 `json:"path_colinearization_ratio,omitempty"` + WaypointDeduplicationToleranceRads *float64 `json:"waypoint_deduplication_tolerance_rads,omitempty"` + VelocityLimitsRadsPerSec float64 `json:"velocity_limits_rads_per_sec,omitempty"` + AccelerationLimitsRadsPerSec2 float64 `json:"acceleration_limits_rads_per_sec2,omitempty"` + SamplingFreqHz *float64 `json:"trajectory_sampling_freq_hz,omitempty"` +} + +// Validate returns the mlmodel service name as a required dependency and checks that velocity and +// acceleration limits are positive. +func (cfg *TrajGenConfig) Validate(path string) ([]string, error) { + if cfg.VelocityLimitsRadsPerSec <= 0 { + return nil, fmt.Errorf("need positive velocity_limits_rads_per_sec if using trajectory_generator, got %v", + cfg.VelocityLimitsRadsPerSec) + } + if cfg.AccelerationLimitsRadsPerSec2 <= 0 { + return nil, fmt.Errorf("need positive acceleration_limits_rads_per_sec2 if using trajectory_generator, got %v", + cfg.AccelerationLimitsRadsPerSec2) + } + if cfg.Service == "" { + return nil, resource.NewConfigValidationFieldRequiredError(path, "service") + } + return []string{cfg.Service}, nil +} + +// ToTrajGen resolves the named mlmodel service from deps and returns a TrajGen ready for use. +func (cfg *TrajGenConfig) ToTrajGen(deps resource.Dependencies) (*TrajGen, error) { + svc, err := mlmodel.FromProvider(deps, cfg.Service) + if err != nil { + return nil, err + } + return NewTrajGen( + svc, + cfg.PathToleranceDeltaRads, + cfg.PathColinearizationRatio, + cfg.WaypointDeduplicationToleranceRads, + cfg.VelocityLimitsRadsPerSec, + cfg.AccelerationLimitsRadsPerSec2, + cfg.SamplingFreqHz, + ), nil +} + +const ( + defaultTrajGenPathToleranceDeltaRads = 0.1 + defaultTrajGenWaypointDeduplicationToleranceRads = 1e-3 + defaultTrajGenSamplingFreqHz = 10.0 + defaultTrajGenPathColinearizationRatio = 0.0 +) + +// DoCommandKeyExecuteTrajGenPlan is the do_command key for sending a precomputed kinodynamic +// trajectory to an arm component that supports it. +const DoCommandKeyExecuteTrajGenPlan = "execute_traj_gen_plan" + +// DoCommandKeySupportsExecuteTrajGenPlan is the capability probe key. Arms that support +// execute_traj_gen_plan respond to this with true. +const DoCommandKeySupportsExecuteTrajGenPlan = "supports_execute_traj_gen_plan" + +// TrajGen holds a resolved trajectory generator ML model service along with its configuration. +type TrajGen struct { + trajGen mlmodel.Service + PathToleranceDeltaRads float64 `json:"path_tolerance_delta_rads"` + PathColinearizationRatio float64 `json:"path_colinearization_ratio"` + WaypointDeduplicationToleranceRads float64 `json:"waypoint_deduplication_tolerance_rads"` + VelocityLimitsRadsPerSec float64 `json:"velocity_limits_rads_per_sec"` + AccelerationLimitsRadsPerSec2 float64 `json:"acceleration_limits_rads_per_sec2"` + SamplingFreqHz float64 `json:"trajectory_sampling_freq_hz"` +} + +func applyDefault(v *float64, def float64) float64 { + if v == nil { + return def + } + return *v +} + +// TrajGenOverride holds per-call overrides for TrajGen settings. Any nil field +// means "use the already-configured value". Pass it to TrajGen.WithOverrides to +// get a modified copy. +type TrajGenOverride struct { + PathToleranceDeltaRads *float64 `json:"path_tolerance_delta_rads,omitempty"` + PathColinearizationRatio *float64 `json:"path_colinearization_ratio,omitempty"` + WaypointDeduplicationToleranceRads *float64 `json:"waypoint_deduplication_tolerance_rads,omitempty"` + VelocityLimitsRadsPerSec *float64 `json:"velocity_limits_rads_per_sec,omitempty"` + AccelerationLimitsRadsPerSec2 *float64 `json:"acceleration_limits_rads_per_sec2,omitempty"` + SamplingFreqHz *float64 `json:"trajectory_sampling_freq_hz,omitempty"` +} + +// WithOverrides returns a shallow copy of tg with any non-nil fields from o applied. +func (tg *TrajGen) WithOverrides(o *TrajGenOverride) *TrajGen { + copy := *tg + if o.PathToleranceDeltaRads != nil { + copy.PathToleranceDeltaRads = *o.PathToleranceDeltaRads + } + if o.PathColinearizationRatio != nil { + copy.PathColinearizationRatio = *o.PathColinearizationRatio + } + if o.WaypointDeduplicationToleranceRads != nil { + copy.WaypointDeduplicationToleranceRads = *o.WaypointDeduplicationToleranceRads + } + if o.VelocityLimitsRadsPerSec != nil { + copy.VelocityLimitsRadsPerSec = *o.VelocityLimitsRadsPerSec + } + if o.AccelerationLimitsRadsPerSec2 != nil { + copy.AccelerationLimitsRadsPerSec2 = *o.AccelerationLimitsRadsPerSec2 + } + if o.SamplingFreqHz != nil { + copy.SamplingFreqHz = *o.SamplingFreqHz + } + return © +} + +// NewTrajGen constructs a TrajGen from an mlmodel service and configuration fields, +// applying defaults for any nil optional values. +func NewTrajGen( + svc mlmodel.Service, + pathToleranceDeltaRads *float64, + pathColinearizationRatio *float64, + waypointDeduplicationToleranceRads *float64, + velocityLimitsRadsPerSec float64, + accelerationLimitsRadsPerSec2 float64, + samplingFreqHz *float64, +) *TrajGen { + return &TrajGen{ + trajGen: svc, + PathToleranceDeltaRads: applyDefault(pathToleranceDeltaRads, defaultTrajGenPathToleranceDeltaRads), + PathColinearizationRatio: applyDefault(pathColinearizationRatio, defaultTrajGenPathColinearizationRatio), + WaypointDeduplicationToleranceRads: applyDefault(waypointDeduplicationToleranceRads, defaultTrajGenWaypointDeduplicationToleranceRads), + VelocityLimitsRadsPerSec: velocityLimitsRadsPerSec, + AccelerationLimitsRadsPerSec2: accelerationLimitsRadsPerSec2, + SamplingFreqHz: applyDefault(samplingFreqHz, defaultTrajGenSamplingFreqHz), + } +} + +// TrajGenPlan is a motionplan.Plan enriched with the kinodynamic data produced by the trajectory +// generator service. Callers that only need joint configurations can use it as a plain Plan; +// callers that need velocities, accelerations, or timestamps can type-assert to *TrajGenPlan. +type TrajGenPlan struct { + *motionplan.SimplePlan + // Configurations holds per-joint positions at each trajectory sample, parallel to Trajectory(). + Configurations []*referenceframe.LinearInputs + // Velocities holds per-joint velocities at each trajectory sample, parallel to Trajectory(). + Velocities []*referenceframe.LinearInputs + // Accelerations holds per-joint accelerations at each sample. It is nil when the service did + // not return acceleration data. + Accelerations []*referenceframe.LinearInputs + // SampleTimes holds the time (in seconds) of each sample, parallel to Trajectory(). + SampleTimes []float64 +} + +// DoCommandPayload returns the map[string]any value for the "execute_traj_gen_plan" do_command key. +func (t *TrajGenPlan) DoCommandPayload() map[string]any { + flatten := func(lis []*referenceframe.LinearInputs) [][]float64 { + out := make([][]float64, len(lis)) + for i, li := range lis { + out[i] = li.GetLinearizedInputs() + } + return out + } + payload := map[string]any{ + "configurations_rads": flatten(t.Configurations), + "velocities_rads_per_sec": flatten(t.Velocities), + "sample_times_sec": t.SampleTimes, + } + if len(t.Accelerations) > 0 { + payload["accelerations_rads_per_sec2"] = flatten(t.Accelerations) + } + return payload +} + +// trajGenResult is the raw output of inferTrajGen. +type trajGenResult struct { + configurations []*referenceframe.LinearInputs + velocities []*referenceframe.LinearInputs + accelerations []*referenceframe.LinearInputs // nil when not provided by the service + sampleTimes []float64 +} + +// inferTrajGen sends the waypoints to the trajectory generator service and returns the resulting +// densely-sampled kinodynamic trajectory. Returns nil when the service indicates the component is +// already at the goal (fewer than 2 distinct waypoints after deduplication). +func inferTrajGen( + ctx context.Context, + fs *referenceframe.FrameSystem, + trajAsInps []*referenceframe.LinearInputs, + tg *TrajGen, +) (*trajGenResult, error) { + if len(trajAsInps) == 0 { + return &trajGenResult{}, nil + } + + schema, err := trajAsInps[0].GetSchema(fs) + if err != nil { + return nil, err + } + + dof := len(trajAsInps[0].GetLinearizedInputs()) + nWaypoints := len(trajAsInps) + + waypoints := make([]float64, 0, nWaypoints*dof) + for _, li := range trajAsInps { + waypoints = append(waypoints, li.GetLinearizedInputs()...) + } + + velLimits := make([]float64, dof) + accelLimits := make([]float64, dof) + for i := range dof { + velLimits[i] = tg.VelocityLimitsRadsPerSec + accelLimits[i] = tg.AccelerationLimitsRadsPerSec2 + } + + outMap, err := tg.trajGen.Infer(ctx, ml.Tensors{ + "waypoints_rads": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(nWaypoints, dof), + tensor.WithBacking(waypoints), + ), + "velocity_limits_rads_per_sec": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(dof), + tensor.WithBacking(velLimits), + ), + "acceleration_limits_rads_per_sec2": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(dof), + tensor.WithBacking(accelLimits), + ), + "path_tolerance_delta_rads": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(1), + tensor.WithBacking([]float64{tg.PathToleranceDeltaRads}), + ), + "path_colinearization_ratio": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(1), + tensor.WithBacking([]float64{tg.PathColinearizationRatio}), + ), + "waypoint_deduplication_tolerance_rads": tensor.New( + tensor.Of(tensor.Float64), + tensor.WithShape(1), + tensor.WithBacking([]float64{tg.WaypointDeduplicationToleranceRads}), + ), + "trajectory_sampling_freq_hz": tensor.New( + tensor.Of(tensor.Int64), + tensor.WithShape(1), + tensor.WithBacking([]int64{int64(tg.SamplingFreqHz)}), + ), + }) + if err != nil { + return nil, err + } + + configsTensor, ok := outMap["configurations_rads"] + if !ok { + // Service returns an empty map when fewer than 2 distinct waypoints remain after + // deduplication -- the arm is already at the goal. + return nil, nil + } + + nSamples := configsTensor.Shape()[0] + + // Helper: convert a flat [n_samples, dof] tensor into []*LinearInputs using the schema. + linearize := func(t *tensor.Dense) ([]*referenceframe.LinearInputs, error) { + data := t.Data().([]float64) + out := make([]*referenceframe.LinearInputs, nSamples) + for i := range nSamples { + li, err := schema.FloatsToInputs(data[i*dof : (i+1)*dof]) + if err != nil { + return nil, err + } + out[i] = li + } + return out, nil + } + + configs, err := linearize(configsTensor) + if err != nil { + return nil, err + } + + velsTensor, ok := outMap["velocities_rads_per_sec"] + if !ok { + return nil, errors.New("trajectory generator service did not return velocities_rads_per_sec") + } + vels, err := linearize(velsTensor) + if err != nil { + return nil, err + } + + times := outMap["sample_times_sec"].Data().([]float64) + + result := &trajGenResult{ + configurations: configs, + velocities: vels, + sampleTimes: times, + } + + if accelTensor, ok := outMap["accelerations_rads_per_sec2"]; ok { + result.accelerations, err = linearize(accelTensor) + if err != nil { + return nil, err + } + } + + return result, nil +} + +// PlanMotionTrajGen plans a motion from a provided plan request using a trajectory generator. +func PlanMotionTrajGen( + ctx context.Context, parentLogger logging.Logger, request *PlanRequest, trajGen *TrajGen, +) (motionplan.Plan, *PlanMeta, error) { + logger := parentLogger.Sublogger("mp") + + start := time.Now() + meta := &PlanMeta{} + ctx, span := trace.StartSpan(ctx, "PlanMotion") + defer func() { + meta.Duration = time.Since(start) + span.End() + }() + + trajAsInps, err := planWaypoints(ctx, logger, request, meta) + if err != nil { + return nil, meta, err + } + + logger.CInfof(ctx, "sending %d waypoints to traj-gen service", len(trajAsInps)) + tgResult, err := inferTrajGen(ctx, request.FrameSystem, trajAsInps, trajGen) + if err != nil { + return nil, meta, err + } + + configs := []*referenceframe.LinearInputs{} + if tgResult != nil { + logger.CInfof(ctx, "traj-gen service returned %d samples (accelerations present: %v)", + len(tgResult.configurations), len(tgResult.accelerations) > 0) + configs = tgResult.configurations + } else { + logger.CInfof(ctx, "traj-gen service indicated arm is already at goal, skipping trajectory") + } + + simplePlan, err := motionplan.NewSimplePlanFromTrajectory(configs, request.FrameSystem) + if err != nil { + return nil, meta, err + } + + t := &TrajGenPlan{ + SimplePlan: simplePlan, + } + if tgResult != nil { + t.Configurations = tgResult.configurations + t.Velocities = tgResult.velocities + t.Accelerations = tgResult.accelerations + t.SampleTimes = tgResult.sampleTimes + } + + if err := CheckPlanFromRequest(ctx, logger, request, t); err != nil { + return nil, meta, err + } + + return t, meta, nil +} diff --git a/motionplan/armplanning/trajgen_test.go b/motionplan/armplanning/trajgen_test.go new file mode 100644 index 00000000000..f9224a33fe1 --- /dev/null +++ b/motionplan/armplanning/trajgen_test.go @@ -0,0 +1,195 @@ +package armplanning + +import ( + "context" + "testing" + + "go.viam.com/test" + + "go.viam.com/rdk/motionplan" + frame "go.viam.com/rdk/referenceframe" +) + +// TestNewTrajGenDefaults verifies that nil optional fields are replaced by their defaults. +func TestNewTrajGenDefaults(t *testing.T) { + tg := NewTrajGen(nil, nil, nil, nil, 1.0, 2.0, nil) + test.That(t, tg.PathToleranceDeltaRads, test.ShouldEqual, defaultTrajGenPathToleranceDeltaRads) + test.That(t, tg.PathColinearizationRatio, test.ShouldEqual, defaultTrajGenPathColinearizationRatio) + test.That(t, tg.WaypointDeduplicationToleranceRads, test.ShouldEqual, defaultTrajGenWaypointDeduplicationToleranceRads) + test.That(t, tg.SamplingFreqHz, test.ShouldEqual, defaultTrajGenSamplingFreqHz) + // Non-optional fields pass through unchanged. + test.That(t, tg.VelocityLimitsRadsPerSec, test.ShouldEqual, 1.0) + test.That(t, tg.AccelerationLimitsRadsPerSec2, test.ShouldEqual, 2.0) +} + +// TestNewTrajGenExplicitValues verifies that non-nil optional fields override the defaults. +func TestNewTrajGenExplicitValues(t *testing.T) { + pt := 0.05 + cr := 0.3 + dd := 0.002 + sf := 20.0 + tg := NewTrajGen(nil, &pt, &cr, &dd, 3.0, 4.0, &sf) + test.That(t, tg.PathToleranceDeltaRads, test.ShouldEqual, 0.05) + test.That(t, tg.PathColinearizationRatio, test.ShouldEqual, 0.3) + test.That(t, tg.WaypointDeduplicationToleranceRads, test.ShouldEqual, 0.002) + test.That(t, tg.SamplingFreqHz, test.ShouldEqual, 20.0) +} + +// TestTrajGenConfigValidate checks all validation rules. +func TestTrajGenConfigValidate(t *testing.T) { + valid := TrajGenConfig{ + Service: "my_svc", + VelocityLimitsRadsPerSec: 1.0, + AccelerationLimitsRadsPerSec2: 2.0, + } + + t.Run("valid config returns service as dependency", func(t *testing.T) { + deps, err := valid.Validate("path") + test.That(t, err, test.ShouldBeNil) + test.That(t, deps, test.ShouldContain, "my_svc") + }) + + t.Run("missing service", func(t *testing.T) { + cfg := valid + cfg.Service = "" + _, err := cfg.Validate("path") + test.That(t, err, test.ShouldNotBeNil) + }) + + t.Run("zero velocity limit", func(t *testing.T) { + cfg := valid + cfg.VelocityLimitsRadsPerSec = 0 + _, err := cfg.Validate("path") + test.That(t, err, test.ShouldNotBeNil) + }) + + t.Run("negative velocity limit", func(t *testing.T) { + cfg := valid + cfg.VelocityLimitsRadsPerSec = -1 + _, err := cfg.Validate("path") + test.That(t, err, test.ShouldNotBeNil) + }) + + t.Run("zero acceleration limit", func(t *testing.T) { + cfg := valid + cfg.AccelerationLimitsRadsPerSec2 = 0 + _, err := cfg.Validate("path") + test.That(t, err, test.ShouldNotBeNil) + }) +} + +// TestTrajGenPlanDoCommandPayload verifies the serialization format used in execute_traj_gen_plan. +func TestTrajGenPlanDoCommandPayload(t *testing.T) { + configs := [][]float64{ + {0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, + {0.7, 0.8, 0.9, 1.0, 1.1, 1.2}, + } + vels := [][]float64{ + {0.01, 0.02, 0.03, 0.04, 0.05, 0.06}, + {0.07, 0.08, 0.09, 0.10, 0.11, 0.12}, + } + times := []float64{0.1, 0.2} + + toLinearInputs := func(rows [][]float64) []*frame.LinearInputs { + lis := make([]*frame.LinearInputs, len(rows)) + for i, row := range rows { + lis[i] = frame.FrameSystemInputs{"arm": row}.ToLinearInputs() + } + return lis + } + + t.Run("without accelerations", func(t *testing.T) { + tgp := &TrajGenPlan{ + SimplePlan: motionplan.NewSimplePlan(nil, nil), + Configurations: toLinearInputs(configs), + Velocities: toLinearInputs(vels), + SampleTimes: times, + } + + payload := tgp.DoCommandPayload() + test.That(t, payload["configurations_rads"], test.ShouldResemble, configs) + test.That(t, payload["velocities_rads_per_sec"], test.ShouldResemble, vels) + test.That(t, payload["sample_times_sec"], test.ShouldResemble, times) + _, hasAccels := payload["accelerations_rads_per_sec2"] + test.That(t, hasAccels, test.ShouldBeFalse) + }) + + t.Run("with accelerations", func(t *testing.T) { + accels := [][]float64{ + {0.001, 0.002, 0.003, 0.004, 0.005, 0.006}, + {0.007, 0.008, 0.009, 0.010, 0.011, 0.012}, + } + tgp := &TrajGenPlan{ + SimplePlan: motionplan.NewSimplePlan(nil, nil), + Configurations: toLinearInputs(configs), + Velocities: toLinearInputs(vels), + Accelerations: toLinearInputs(accels), + SampleTimes: times, + } + + payload := tgp.DoCommandPayload() + test.That(t, payload["accelerations_rads_per_sec2"], test.ShouldResemble, accels) + }) +} + +// TestDetectMovingFrames verifies the single-pass moving-frame detection logic. +func TestDetectMovingFrames(t *testing.T) { + t.Run("empty trajectory returns empty map", func(t *testing.T) { + moving := detectMovingFrames(motionplan.Trajectory{}) + test.That(t, moving, test.ShouldBeEmpty) + }) + + t.Run("single step — no pairs to compare, fallback treats all as moving", func(t *testing.T) { + traj := motionplan.Trajectory{{"arm": []float64{0, 0, 0}}} + moving := detectMovingFrames(traj) + test.That(t, moving, test.ShouldContainKey, "arm") + }) + + t.Run("frame with identical inputs is not moving; fallback fires for all-static", func(t *testing.T) { + traj := motionplan.Trajectory{ + {"arm": []float64{1, 2, 3}}, + {"arm": []float64{1, 2, 3}}, + } + // Nothing moved → fallback treats all frames as moving. + moving := detectMovingFrames(traj) + test.That(t, moving, test.ShouldContainKey, "arm") + }) + + t.Run("frame with changing inputs is detected as moving", func(t *testing.T) { + traj := motionplan.Trajectory{ + {"arm": []float64{0, 0, 0}}, + {"arm": []float64{1, 0, 0}}, + {"arm": []float64{2, 0, 0}}, + } + moving := detectMovingFrames(traj) + test.That(t, moving, test.ShouldContainKey, "arm") + }) + + t.Run("only the moving frame among multiple frames is returned", func(t *testing.T) { + traj := motionplan.Trajectory{ + {"arm": []float64{0, 0, 0}, "gripper": []float64{0}}, + {"arm": []float64{1, 0, 0}, "gripper": []float64{0}}, + } + moving := detectMovingFrames(traj) + test.That(t, moving, test.ShouldContainKey, "arm") + test.That(t, moving, test.ShouldNotContainKey, "gripper") + }) + + t.Run("detects change that only appears in a later step", func(t *testing.T) { + traj := motionplan.Trajectory{ + {"arm": []float64{0, 0, 0}}, + {"arm": []float64{0, 0, 0}}, + {"arm": []float64{0, 0, 1}}, // change only in last step + } + moving := detectMovingFrames(traj) + test.That(t, moving, test.ShouldContainKey, "arm") + }) +} + +// TestInferTrajGenEmptyWaypoints verifies the early-return for an empty waypoint list. +func TestInferTrajGenEmptyWaypoints(t *testing.T) { + result, err := inferTrajGen(context.Background(), nil, nil, nil) + test.That(t, err, test.ShouldBeNil) + test.That(t, result, test.ShouldNotBeNil) + test.That(t, result.configurations, test.ShouldBeEmpty) +} diff --git a/services/motion/builtin/builtin.go b/services/motion/builtin/builtin.go index f428d80b394..a4633c1d425 100644 --- a/services/motion/builtin/builtin.go +++ b/services/motion/builtin/builtin.go @@ -3,6 +3,7 @@ package builtin import ( "context" + "encoding/json" "fmt" "math" "os" @@ -92,6 +93,8 @@ type Config struct { // example { "arm" : { "3" : { "min" : 0, "max" : 2 } } } InputRangeOverride map[string]map[string]referenceframe.Limit `json:"input_range_override"` + + TrajGen *armplanning.TrajGenConfig `json:"trajectory_generator,omitempty"` } func (c *Config) shouldWritePlan(start time.Time, err error) bool { @@ -109,6 +112,7 @@ func (c *Config) shouldWritePlan(start time.Time, err error) bool { // Validate here adds a dependency on the internal framesystem service. func (c *Config) Validate(path string) ([]string, []string, error) { + deps := []string{framesystem.InternalServiceName.String()} if c.NumThreads < 0 { return nil, nil, fmt.Errorf("cannot configure with %d number of threads, number must be positive", c.NumThreads) } @@ -120,8 +124,15 @@ func (c *Config) Validate(path string) ([]string, []string, error) { if c.LogSlowPlanThresholdMS != 0 && c.PlanFilePath == "" { return nil, nil, fmt.Errorf("need a plan_file_path if you sent LogSlowPlanThresholdMS to %v", c.LogSlowPlanThresholdMS) } + if c.TrajGen != nil { + trajGenDeps, err := c.TrajGen.Validate(path) + if err != nil { + return nil, nil, err + } + deps = append(deps, trajGenDeps...) + } - return []string{framesystem.InternalServiceName.String()}, nil, nil + return deps, nil, nil } type builtIn struct { @@ -135,6 +146,7 @@ type builtIn struct { components map[string]resource.Resource logger logging.Logger configuredDefaultExtras map[string]any + trajGen *armplanning.TrajGen } // NewBuiltIn returns a new move and grab service for the given robot. @@ -175,6 +187,7 @@ func (ms *builtIn) Reconfigure( if config.NumThreads > 0 { ms.configuredDefaultExtras["num_threads"] = config.NumThreads } + ms.configuredDefaultExtras["skipTrajGen"] = false movementSensors := make(map[string]movementsensor.MovementSensor) slamServices := make(map[string]slam.Service) @@ -199,6 +212,14 @@ func (ms *builtIn) Reconfigure( ms.visionServices = visionServices ms.components = componentMap + ms.trajGen = nil + if config.TrajGen != nil { + ms.trajGen, err = config.TrajGen.ToTrajGen(deps) + if err != nil { + return err + } + } + return nil } @@ -216,7 +237,7 @@ func (ms *builtIn) Move(ctx context.Context, req motion.MoveReq) (bool, error) { if err != nil { return false, err } - err = ms.execute(ctx, plan.Trajectory(), math.MaxFloat64) + err = ms.execute(ctx, plan, math.MaxFloat64) return err == nil, err } @@ -349,7 +370,7 @@ func (ms *builtIn) DoCommand(ctx context.Context, cmd map[string]interface{}) (m resp[DoExecuteCheckStart] = "resource at starting location" } - if err := ms.execute(ctx, trajectory, epsilon); err != nil { + if err := ms.execute(ctx, motionplan.NewSimplePlan(nil, trajectory), epsilon); err != nil { return nil, err } resp[DoExecute] = true @@ -480,8 +501,26 @@ func (ms *builtIn) plan(ctx context.Context, req motion.MoveReq, logger logging. PlannerOptions: planOpts, } + skipTrajGen := req.Extra["skipTrajGen"].(bool) + trajGen := ms.trajGen + if overrideIface, ok := req.Extra["trajectory_generator"]; ok && trajGen != nil { + overrideMap, ok := overrideIface.(map[string]any) + if !ok { + return nil, errors.New("extras trajectory_generator must be a map") + } + var override armplanning.TrajGenOverride + if err := decodeJSONTagged(overrideMap, &override); err != nil { + return nil, fmt.Errorf("trajectory_generator override: %w", err) + } + trajGen = trajGen.WithOverrides(&override) + } start := time.Now() - plan, _, err := armplanning.PlanMotion(ctx, logger, planRequest) + var plan motionplan.Plan + if !skipTrajGen && trajGen != nil { + plan, _, err = armplanning.PlanMotionTrajGen(ctx, logger, planRequest, trajGen) + } else { + plan, _, err = armplanning.PlanMotion(ctx, logger, planRequest) + } if ms.conf.shouldWritePlan(start, err) { var traceID string if span := trace.FromContext(ctx); span != nil { @@ -504,7 +543,34 @@ func (ms *builtIn) plan(ctx context.Context, req motion.MoveReq, logger logging. return plan, err } -func (ms *builtIn) execute(ctx context.Context, trajectory motionplan.Trajectory, epsilon float64) error { +func (ms *builtIn) execute(ctx context.Context, plan motionplan.Plan, epsilon float64) error { + // When the plan carries kinodynamic data and the target arm supports the precomputed + // trajectory path, send everything in one DoCommand instead of batching GoToInputs. + if tgp, ok := plan.(*armplanning.TrajGenPlan); ok && len(tgp.Configurations) > 0 { + var armName string + for name := range tgp.Trajectory()[0] { + armName = name + break + } + if armName != "" { + if r, ok := ms.components[armName]; ok { + capResp, err := r.DoCommand(ctx, map[string]any{armplanning.DoCommandKeySupportsExecuteTrajGenPlan: true}) + if err == nil { + if v, _ := capResp[armplanning.DoCommandKeySupportsExecuteTrajGenPlan].(bool); v { + ms.logger.CInfof(ctx, "executing traj-gen plan on %q (%d samples)", armName, len(tgp.Configurations)) + _, err = r.DoCommand(ctx, map[string]any{ + armplanning.DoCommandKeyExecuteTrajGenPlan: tgp.DoCommandPayload(), + }) + return err + } + } + ms.logger.CInfof(ctx, "arm %q does not support %s, falling back to GoToInputs", + armName, armplanning.DoCommandKeyExecuteTrajGenPlan) + } + } + } + + trajectory := plan.Trajectory() // Batch GoToInputs calls if possible; components may want to blend between inputs combinedSteps := []map[string][][]referenceframe.Input{} currStep := map[string][][]referenceframe.Input{} @@ -679,6 +745,16 @@ func waypointsFromRequest( return startState, waypoints, nil } +// decodeJSONTagged decodes a map[string]any into a struct by round-tripping +// through JSON, so that json struct tags are honoured. +func decodeJSONTagged(m map[string]any, dst any) error { + b, err := json.Marshal(m) + if err != nil { + return err + } + return json.Unmarshal(b, dst) +} + func (ms *builtIn) writePlanRequest( req *armplanning.PlanRequest, plan motionplan.Plan, start time.Time, traceID, planTag string, planError error, ) error { @@ -721,5 +797,23 @@ func (ms *builtIn) writePlanRequest( } ms.logger.Infof("writing plan to %s", fn) - return req.WriteToFile(fn) + if err := req.WriteToFile(fn); err != nil { + return err + } + + if tgp, ok := plan.(*armplanning.TrajGenPlan); ok && ms.trajGen != nil { + trajGenFn := strings.TrimSuffix(fn, filepath.Ext(fn)) + "-trajgen.json" + ms.logger.Infof("writing traj-gen plan to %s", trajGenFn) + data, err := json.Marshal(map[string]any{ + "settings": ms.trajGen, + "plan": tgp.DoCommandPayload(), + }) + if err != nil { + return err + } + if err := os.WriteFile(filepath.Clean(trajGenFn), data, 0o600); err != nil { + return err + } + } + return nil } diff --git a/services/motion/builtin/builtin_test.go b/services/motion/builtin/builtin_test.go index e038a11e8a8..6fd56775a60 100644 --- a/services/motion/builtin/builtin_test.go +++ b/services/motion/builtin/builtin_test.go @@ -750,3 +750,121 @@ func TestWritePlanRequest(t *testing.T) { // Verify the filename contains the custom tag test.That(t, planFile.Name(), test.ShouldContainSubstring, "custom-test-tag") } + +// fakeDoCommandArm is a minimal resource.Resource whose DoCommand records calls and returns +// pre-configured responses. It does not implement InputEnabled and is only suitable for tests +// that exercise the TrajGenPlan fast path (which returns before GoToInputs). +type fakeDoCommandArm struct { + calls []map[string]interface{} + responses []map[string]interface{} +} + +func (f *fakeDoCommandArm) Name() resource.Name { return resource.Name{} } +func (f *fakeDoCommandArm) Reconfigure(_ context.Context, _ resource.Dependencies, _ resource.Config) error { + return nil +} +func (f *fakeDoCommandArm) Close(_ context.Context) error { return nil } +func (f *fakeDoCommandArm) DoCommand(_ context.Context, cmd map[string]interface{}) (map[string]interface{}, error) { + f.calls = append(f.calls, cmd) + if len(f.responses) > 0 { + resp := f.responses[0] + f.responses = f.responses[1:] + return resp, nil + } + return nil, nil +} + +// TestExecuteTrajGenFastPath verifies that execute() sends execute_traj_gen_plan via DoCommand +// when the plan is a *TrajGenPlan and the arm reports capability support. +func TestExecuteTrajGenFastPath(t *testing.T) { + ctx := context.Background() + logger := logging.NewTestLogger(t) + + configs := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6} + vels := []float64{0.01, 0.02, 0.03, 0.04, 0.05, 0.06} + times := []float64{0.1} + + toLinearInputs := func(vals []float64) []*referenceframe.LinearInputs { + return []*referenceframe.LinearInputs{ + referenceframe.FrameSystemInputs{"arm": vals}.ToLinearInputs(), + } + } + + tgp := &armplanning.TrajGenPlan{ + SimplePlan: motionplan.NewSimplePlan(nil, motionplan.Trajectory{{"arm": configs}}), + Configurations: toLinearInputs(configs), + Velocities: toLinearInputs(vels), + SampleTimes: times, + } + + arm := &fakeDoCommandArm{ + responses: []map[string]interface{}{ + // First call: capability probe → supported. + {armplanning.DoCommandKeySupportsExecuteTrajGenPlan: true}, + // Second call: execute → success. + {armplanning.DoCommandKeyExecuteTrajGenPlan: true}, + }, + } + + ms := &builtIn{ + components: map[string]resource.Resource{"arm": arm}, + logger: logger, + } + + err := ms.execute(ctx, tgp, math.MaxFloat64) + test.That(t, err, test.ShouldBeNil) + + // Two DoCommand calls should have been made. + test.That(t, arm.calls, test.ShouldHaveLength, 2) + + // First call is the capability probe. + _, hasCapKey := arm.calls[0][armplanning.DoCommandKeySupportsExecuteTrajGenPlan] + test.That(t, hasCapKey, test.ShouldBeTrue) + + // Second call carries the trajectory payload. + payload, hasExecKey := arm.calls[1][armplanning.DoCommandKeyExecuteTrajGenPlan] + test.That(t, hasExecKey, test.ShouldBeTrue) + payloadMap, ok := payload.(map[string]any) + test.That(t, ok, test.ShouldBeTrue) + test.That(t, payloadMap["configurations_rads"], test.ShouldResemble, [][]float64{configs}) + test.That(t, payloadMap["velocities_rads_per_sec"], test.ShouldResemble, [][]float64{vels}) + test.That(t, payloadMap["sample_times_sec"], test.ShouldResemble, times) +} + +// TestExecuteTrajGenCapabilityNotSupported verifies that execute() does not send +// execute_traj_gen_plan when the capability probe returns false, and instead falls +// through to the normal GoToInputs path. +func TestExecuteTrajGenCapabilityNotSupported(t *testing.T) { + ctx := context.Background() + logger := logging.NewTestLogger(t) + + configs := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6} + + tgp := &armplanning.TrajGenPlan{ + SimplePlan: motionplan.NewSimplePlan(nil, motionplan.Trajectory{{"arm": configs}}), + Configurations: []*referenceframe.LinearInputs{ + referenceframe.FrameSystemInputs{"arm": configs}.ToLinearInputs(), + }, + Velocities: []*referenceframe.LinearInputs{}, + SampleTimes: []float64{}, + } + + arm := &fakeDoCommandArm{ + responses: []map[string]interface{}{ + // Capability probe returns false. + {armplanning.DoCommandKeySupportsExecuteTrajGenPlan: false}, + }, + } + + ms := &builtIn{ + components: map[string]resource.Resource{"arm": arm}, + logger: logger, + } + + // execute falls back to GoToInputs, which will fail because fakeDoCommandArm doesn't + // implement InputEnabled — but only one DoCommand call (the probe) should have been made. + _ = ms.execute(ctx, tgp, math.MaxFloat64) + test.That(t, arm.calls, test.ShouldHaveLength, 1) + _, hasCapKey := arm.calls[0][armplanning.DoCommandKeySupportsExecuteTrajGenPlan] + test.That(t, hasCapKey, test.ShouldBeTrue) +}