Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions ldai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ func (c *Client) Config(
return c.CompletionConfig(key, context, defaultValue, variables)
}

// CreateTracker reconstructs a Tracker from a resumption token and the given context.
// This delegates to TrackerFromResumptionToken. See that function for details.
func (c *Client) CreateTracker(token string, context ldcontext.Context) (*Tracker, error) {
return TrackerFromResumptionToken(token, c.sdk, context)
}

// returnDefault sets a tracker factory on a copy of def (so CreateTracker always works) and
// returns it along with an initial tracker. Used for all error-path returns in evaluateConfig.
func (c *Client) returnDefault(key string, context ldcontext.Context, def Config) (Config, *Tracker) {
def.trackerFactory = func() *Tracker {
return newTracker(c.sdk, newRunID(), key, "", 1, context, &def, c.logger)
}
return def, newTracker(c.sdk, newRunID(), key, "", 1, context, &def, c.logger)
}

// evaluateConfig fetches and interpolates an AI Config without emitting any metric.
// Callers (Config, JudgeConfig) are meant to emit their own metric before calling this.
func (c *Client) evaluateConfig(
Expand All @@ -125,13 +140,13 @@ func (c *Client) evaluateConfig(
// empty object.)
if result.Type() != ldvalue.ObjectType {
c.logConfigWarning(key, "unmarshalling failed, expected JSON object but got %s", result.Type().String())
return defaultValue, newTracker(key, "", 1, c.sdk, &defaultValue, context, c.logger)
return c.returnDefault(key, context, defaultValue)
}

var parsed datamodel.Config
if err := json.Unmarshal([]byte(result.JSONString()), &parsed); err != nil {
c.logConfigWarning(key, "unmarshalling failed: %v", err)
return defaultValue, newTracker(key, "", 1, c.sdk, &defaultValue, context, c.logger)
return c.returnDefault(key, context, defaultValue)
}

mergedVariables := map[string]interface{}{
Expand Down Expand Up @@ -169,7 +184,7 @@ func (c *Client) evaluateConfig(
c.logConfigWarning(key,
"malformed message at index %d: %v", i, err,
)
return defaultValue, &Tracker{}
return c.returnDefault(key, context, defaultValue)
}
builder.WithMessage(content, msg.Role)
}
Expand All @@ -181,7 +196,12 @@ func (c *Client) evaluateConfig(
version = *parsed.Meta.Version
}

return cfg, newTracker(key, parsed.Meta.VariationKey, version, c.sdk, &cfg, context, c.logger)
variationKey := parsed.Meta.VariationKey
cfg.trackerFactory = func() *Tracker {
return newTracker(c.sdk, newRunID(), key, variationKey, version, context, &cfg, c.logger)
}

return cfg, newTracker(c.sdk, newRunID(), key, parsed.Meta.VariationKey, version, context, &cfg, c.logger)
}

func getAllAttributes(context ldcontext.Context) map[string]interface{} {
Expand Down
245 changes: 243 additions & 2 deletions ldai/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ldai

import (
"encoding/base64"
"encoding/json"
"errors"
"testing"

Expand Down Expand Up @@ -96,7 +98,10 @@ func TestEvalErrorReturnsDefault(t *testing.T) {

cfg, tracker := client.Config("key", ldcontext.New("user"), defaultVal, nil)
assert.NotNil(t, tracker)
assert.Equal(t, defaultVal, cfg)
assert.Equal(t, defaultVal.Enabled(), cfg.Enabled())
assert.Equal(t, defaultVal.Messages(), cfg.Messages())
assert.Equal(t, defaultVal.ModelName(), cfg.ModelName())
assert.Equal(t, defaultVal.ProviderName(), cfg.ProviderName())
}

func TestParseMultipleMessages(t *testing.T) {
Expand Down Expand Up @@ -191,7 +196,10 @@ func TestParseInvalidConfigReturnsDefault(t *testing.T) {
defaultVal := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()

cfg, _ := client.Config("key", ldcontext.New("user"), defaultVal, nil)
assert.Equal(t, defaultVal, cfg)
// Verify config data matches the default
assert.Equal(t, defaultVal.AsLdValue(), cfg.AsLdValue())
// Verify CreateTracker() now works (returnDefault always injects a factory)
assert.NotNil(t, cfg.CreateTracker())

sdk.log.AssertMessageMatch(t, true, ldlog.Warn, "AI Config 'key':")
})
Expand Down Expand Up @@ -835,3 +843,236 @@ func TestConfig_WithoutReservedVarsWipesJudgePlaceholders(t *testing.T) {
require.Len(t, msgs, 1)
assert.Equal(t, "Input: \nOutput: ", msgs[0].Content, "Config without reserved vars renders placeholders as empty")
}

func TestCreateTracker_ManuallyBuiltConfig_ReturnsNil(t *testing.T) {
cfg := NewConfig().Enable().WithMessage("hello", datamodel.User).Build()
assert.Nil(t, cfg.CreateTracker(), "manually built config should not have a tracker factory")
}

func TestCreateTracker_DisabledConfig_ReturnsTracker(t *testing.T) {
json := []byte(`{
"_ldMeta": {"variationKey": "1", "enabled": false},
"messages": [{"content": "hello", "role": "user"}]
}`)

client, err := NewClient(newMockSDK(json, nil))
require.NoError(t, err)

cfg, _ := client.CompletionConfig("key", ldcontext.New("user"), Disabled(), nil)
assert.False(t, cfg.Enabled())
assert.NotNil(t, cfg.CreateTracker(), "disabled config should still have a tracker factory")
}

func TestCreateTracker_EnabledConfig_ReturnsTracker(t *testing.T) {
json := []byte(`{
"_ldMeta": {"variationKey": "1", "enabled": true},
"model": {"name": "gpt-4"},
"provider": {"name": "openai"},
"messages": [{"content": "hello", "role": "user"}]
}`)

client, err := NewClient(newMockSDK(json, nil))
require.NoError(t, err)

cfg, _ := client.CompletionConfig("key", ldcontext.New("user"), Disabled(), nil)
assert.True(t, cfg.Enabled())

tracker := cfg.CreateTracker()
require.NotNil(t, tracker, "enabled config should have a tracker factory")
}

func TestCreateTracker_FreshRunIdPerCall(t *testing.T) {
json := []byte(`{
"_ldMeta": {"variationKey": "1", "enabled": true},
"messages": [{"content": "hello", "role": "user"}]
}`)

mockSDK := newMockSDK(json, nil)
client, err := NewClient(mockSDK)
require.NoError(t, err)

// Clear SDK info event
mockSDK.events = nil

cfg, _ := client.CompletionConfig("key", ldcontext.New("user"), Disabled(), nil)

tracker1 := cfg.CreateTracker()
tracker2 := cfg.CreateTracker()
require.NotNil(t, tracker1)
require.NotNil(t, tracker2)

// Each tracker should be able to track independently. Track success on both to emit events.
_ = tracker1.TrackSuccess()
_ = tracker2.TrackSuccess()

// Filter out the usage event; we only want the generation events.
var genEvents []mockEvent
for _, e := range mockSDK.events {
if e.eventName == "$ld:ai:generation:success" {
genEvents = append(genEvents, e)
}
}

require.Len(t, genEvents, 2, "each tracker should emit its own event")

runId1 := genEvents[0].data.GetByKey("runId").StringValue()
runId2 := genEvents[1].data.GetByKey("runId").StringValue()
assert.NotEmpty(t, runId1)
assert.NotEmpty(t, runId2)
assert.NotEqual(t, runId1, runId2, "each tracker must have a unique runId")
}

func TestCreateTracker_TrackerHasCorrectMetadata(t *testing.T) {
json := []byte(`{
"_ldMeta": {"variationKey": "var-1", "enabled": true, "version": 5},
"model": {"name": "gpt-4"},
"provider": {"name": "openai"},
"messages": [{"content": "hello", "role": "user"}]
}`)

mockSDK := newMockSDK(json, nil)
client, err := NewClient(mockSDK)
require.NoError(t, err)

// Clear SDK info event
mockSDK.events = nil

cfg, _ := client.CompletionConfig("my-config", ldcontext.New("user"), Disabled(), nil)

tracker := cfg.CreateTracker()
require.NotNil(t, tracker)

_ = tracker.TrackSuccess()

// Filter for the generation event (skip usage event)
var genEvent *mockEvent
for i, e := range mockSDK.events {
if e.eventName == "$ld:ai:generation:success" {
genEvent = &mockSDK.events[i]
break
}
}
require.NotNil(t, genEvent)

data := genEvent.data
assert.Equal(t, "my-config", data.GetByKey("configKey").StringValue())
assert.Equal(t, "var-1", data.GetByKey("variationKey").StringValue())
assert.Equal(t, 5, data.GetByKey("version").IntValue())
assert.Equal(t, "openai", data.GetByKey("providerName").StringValue())
assert.Equal(t, "gpt-4", data.GetByKey("modelName").StringValue())
assert.NotEmpty(t, data.GetByKey("runId").StringValue())
}

func TestCreateTracker_JudgeConfigHasFactory(t *testing.T) {
json := []byte(`{
"_ldMeta": {"variationKey": "1", "enabled": true},
"mode": "judge",
"evaluationMetricKey": "toxicity",
"messages": [{"content": "test", "role": "system"}]
}`)

client, err := NewClient(newMockSDK(json, nil))
require.NoError(t, err)

cfg, _ := client.JudgeConfig("judge-key", ldcontext.New("user"), Disabled(), nil)
assert.True(t, cfg.Enabled())

tracker := cfg.CreateTracker()
require.NotNil(t, tracker, "enabled judge config should have a tracker factory")
}

func TestClient_CreateTracker_RoundTrip(t *testing.T) {
configJSON := []byte(`{
"_ldMeta": {"variationKey": "var-1", "enabled": true, "version": 5},
"model": {"name": "gpt-4"},
"provider": {"name": "openai"},
"messages": [{"content": "hello", "role": "user"}]
}`)

mockSDK := newMockSDK(configJSON, nil)
client, err := NewClient(mockSDK)
require.NoError(t, err)

// Clear SDK info event
mockSDK.events = nil

cfg, _ := client.CompletionConfig("my-config", ldcontext.New("user"), Disabled(), nil)
originalTracker := cfg.CreateTracker()
require.NotNil(t, originalTracker)

token := originalTracker.ResumptionToken()
require.NotEmpty(t, token)

// Reconstruct from token with a different context
newContext := ldcontext.New("other-user")
reconstructed, err := client.CreateTracker(token, newContext)
require.NoError(t, err)
require.NotNil(t, reconstructed)

// The reconstructed tracker should produce the same resumption token
assert.Equal(t, token, reconstructed.ResumptionToken())

// Track feedback on the reconstructed tracker and verify it uses the original runId
_ = originalTracker.TrackSuccess()
_ = reconstructed.TrackFeedback(FeedbackPositive)

var successEvent, feedbackEvent *mockEvent
for i, e := range mockSDK.events {
switch e.eventName {
case "$ld:ai:generation:success":
successEvent = &mockSDK.events[i]
case "$ld:ai:feedback:user:positive":
feedbackEvent = &mockSDK.events[i]
}
}
require.NotNil(t, successEvent)
require.NotNil(t, feedbackEvent)

// Both events should share the same runId
originalRunId := successEvent.data.GetByKey("runId").StringValue()
reconstructedRunId := feedbackEvent.data.GetByKey("runId").StringValue()
assert.Equal(t, originalRunId, reconstructedRunId, "reconstructed tracker must reuse the original runId")

// Reconstructed tracker should use the new context
assert.Equal(t, newContext, feedbackEvent.context)

// Verify metadata preserved
assert.Equal(t, "my-config", feedbackEvent.data.GetByKey("configKey").StringValue())
assert.Equal(t, "var-1", feedbackEvent.data.GetByKey("variationKey").StringValue())
assert.Equal(t, 5, feedbackEvent.data.GetByKey("version").IntValue())

// modelName and providerName should be empty on reconstructed tracker
assert.Equal(t, "", feedbackEvent.data.GetByKey("modelName").StringValue())
assert.Equal(t, "", feedbackEvent.data.GetByKey("providerName").StringValue())
}

func TestClient_CreateTracker_InvalidToken(t *testing.T) {
mockSDK := newMockSDK(nil, nil)
client, err := NewClient(mockSDK)
require.NoError(t, err)

t.Run("invalid base64", func(t *testing.T) {
_, err := client.CreateTracker("not-valid-base64!!!", ldcontext.New("user"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid resumption token")
})

t.Run("valid base64 but invalid JSON", func(t *testing.T) {
token := base64.RawURLEncoding.EncodeToString([]byte("not json"))
_, err := client.CreateTracker(token, ldcontext.New("user"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid resumption token")
})

t.Run("valid token with missing fields uses zero values", func(t *testing.T) {
payload, _ := json.Marshal(map[string]interface{}{"runId": "test-run"})
token := base64.RawURLEncoding.EncodeToString(payload)
tracker, err := client.CreateTracker(token, ldcontext.New("user"))
require.NoError(t, err)
require.NotNil(t, tracker)

// Should work with partial data
resumeToken := tracker.ResumptionToken()
assert.NotEmpty(t, resumeToken)
})
}
13 changes: 12 additions & 1 deletion ldai/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (

// Config represents an AI Config.
type Config struct {
c datamodel.Config
c datamodel.Config
trackerFactory func() *Tracker
}

// VariationKey is used internally by LaunchDarkly.
Expand Down Expand Up @@ -87,6 +88,16 @@ func (c *Config) JudgeConfiguration() *datamodel.JudgeConfiguration {
}
}

// CreateTracker creates a new Tracker with a fresh runId for tracking metrics related to this
// AI Config evaluation. Each call returns a new, independent Tracker instance.
// Returns nil if the config was not obtained via the Client.
func (c *Config) CreateTracker() *Tracker {
if c.trackerFactory == nil {
return nil
}
return c.trackerFactory()
}

// AsLdValue is used internally.
func (c *Config) AsLdValue() ldvalue.Value {
return ldvalue.FromJSONMarshal(c.c)
Expand Down
Loading
Loading