diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 965a05a0b007..4eb7c54f3b11 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -37,6 +37,7 @@ import ( "github.com/xtls/xray-core/transport/internet/tcp" "github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/websocket" + "github.com/xtls/xray-core/transport/internet/xdrive" "google.golang.org/protobuf/proto" ) @@ -452,6 +453,41 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) { return config, nil } +// XDriveConfig represents the configuration for the XDRIVE transport. +// For the "local" service, secrets[0] should be "client" for client-side +// or "server" for server-side. +// For "Google Drive" service, secrets should contain [role, ClientID, ClientSecret, RefreshToken] +// where role is "client" or "server". +type XDriveConfig struct { + RemoteFolder string `json:"remoteFolder"` + Service string `json:"service"` + Secrets []string `json:"secrets"` +} + +// Build implements Buildable. +func (c *XDriveConfig) Build() (proto.Message, error) { + switch c.Service { + case "local": + // For local service, secrets[0] must be "client" or "server" + if len(c.Secrets) < 1 { + return nil, errors.New("local service needs secrets[0] set to 'client' or 'server'") + } + case "Google Drive": + // For Google Drive, secrets should be [role, ClientID, ClientSecret, RefreshToken] + if len(c.Secrets) != 4 { + return nil, errors.New("Google Drive needs 4 secrets: role ('client' or 'server'), ClientID, ClientSecret, RefreshToken") + } + default: + return nil, errors.New("unsupported service") + } + config := &xdrive.Config{ + RemoteFolder: c.RemoteFolder, + Service: c.Service, + Secrets: c.Secrets, + } + return config, nil +} + const ( Byte = 1 Kilobyte = 1024 * Byte @@ -1054,6 +1090,8 @@ func (p TransportProtocol) Build() (string, error) { return "", errors.PrintRemovedFeatureError("QUIC transport (without web service, etc.)", "XHTTP stream-one H3") case "hysteria": return "hysteria", nil + case "xdrive": + return "xdrive", nil default: return "", errors.New("Config: unknown transport protocol: ", p) } @@ -1403,6 +1441,7 @@ type StreamConfig struct { WSSettings *WebSocketConfig `json:"wsSettings"` HTTPUPGRADESettings *HttpUpgradeConfig `json:"httpupgradeSettings"` HysteriaSettings *HysteriaConfig `json:"hysteriaSettings"` + XDriveSettings *XDriveConfig `json:"xdriveSettings"` SocketSettings *SocketConfig `json:"sockopt"` } @@ -1533,6 +1572,16 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { Settings: serial.ToTypedMessage(hs), }) } + if c.XDriveSettings != nil { + xs, err := c.XDriveSettings.Build() + if err != nil { + return nil, errors.New("Failed to build XDrive config.").Base(err) + } + config.TransportSettings = append(config.TransportSettings, &internet.TransportConfig{ + ProtocolName: "xdrive", + Settings: serial.ToTypedMessage(xs), + }) + } if c.SocketSettings != nil { ss, err := c.SocketSettings.Build() if err != nil { diff --git a/main/distro/all/all.go b/main/distro/all/all.go index 11b58d9215d3..219f34193716 100644 --- a/main/distro/all/all.go +++ b/main/distro/all/all.go @@ -59,6 +59,7 @@ import ( _ "github.com/xtls/xray-core/transport/internet/tls" _ "github.com/xtls/xray-core/transport/internet/udp" _ "github.com/xtls/xray-core/transport/internet/websocket" + _ "github.com/xtls/xray-core/transport/internet/xdrive" // Transport headers _ "github.com/xtls/xray-core/transport/internet/headers/http" diff --git a/transport/internet/xdrive/config.pb.go b/transport/internet/xdrive/config.pb.go new file mode 100644 index 000000000000..cfa657f18697 --- /dev/null +++ b/transport/internet/xdrive/config.pb.go @@ -0,0 +1,140 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.5 +// source: transport/internet/xdrive/config.proto + +package xdrive + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + RemoteFolder string `protobuf:"bytes,1,opt,name=remote_folder,json=remoteFolder,proto3" json:"remote_folder,omitempty"` + Service string `protobuf:"bytes,2,opt,name=service,proto3" json:"service,omitempty"` + Secrets []string `protobuf:"bytes,3,rep,name=secrets,proto3" json:"secrets,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_xdrive_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_xdrive_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_xdrive_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetRemoteFolder() string { + if x != nil { + return x.RemoteFolder + } + return "" +} + +func (x *Config) GetService() string { + if x != nil { + return x.Service + } + return "" +} + +func (x *Config) GetSecrets() []string { + if x != nil { + return x.Secrets + } + return nil +} + +var File_transport_internet_xdrive_config_proto protoreflect.FileDescriptor + +const file_transport_internet_xdrive_config_proto_rawDesc = "" + + "\n" + + "&transport/internet/xdrive/config.proto\x12\x1exray.transport.internet.xdrive\"a\n" + + "\x06Config\x12#\n" + + "\rremote_folder\x18\x01 \x01(\tR\fremoteFolder\x12\x18\n" + + "\aservice\x18\x02 \x01(\tR\aservice\x12\x18\n" + + "\asecrets\x18\x03 \x03(\tR\asecretsB5Z3github.com/xtls/xray-core/transport/internet/xdriveb\x06proto3" + +var ( + file_transport_internet_xdrive_config_proto_rawDescOnce sync.Once + file_transport_internet_xdrive_config_proto_rawDescData []byte +) + +func file_transport_internet_xdrive_config_proto_rawDescGZIP() []byte { + file_transport_internet_xdrive_config_proto_rawDescOnce.Do(func() { + file_transport_internet_xdrive_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_xdrive_config_proto_rawDesc), len(file_transport_internet_xdrive_config_proto_rawDesc))) + }) + return file_transport_internet_xdrive_config_proto_rawDescData +} + +var file_transport_internet_xdrive_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_xdrive_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.xdrive.Config +} +var file_transport_internet_xdrive_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_xdrive_config_proto_init() } +func file_transport_internet_xdrive_config_proto_init() { + if File_transport_internet_xdrive_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_xdrive_config_proto_rawDesc), len(file_transport_internet_xdrive_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_xdrive_config_proto_goTypes, + DependencyIndexes: file_transport_internet_xdrive_config_proto_depIdxs, + MessageInfos: file_transport_internet_xdrive_config_proto_msgTypes, + }.Build() + File_transport_internet_xdrive_config_proto = out.File + file_transport_internet_xdrive_config_proto_goTypes = nil + file_transport_internet_xdrive_config_proto_depIdxs = nil +} diff --git a/transport/internet/xdrive/config.proto b/transport/internet/xdrive/config.proto new file mode 100644 index 000000000000..cc545753cae0 --- /dev/null +++ b/transport/internet/xdrive/config.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package xray.transport.internet.xdrive; +option go_package = "github.com/xtls/xray-core/transport/internet/xdrive"; + +message Config { + string remote_folder = 1; + string service = 2; + repeated string secrets = 3; +} diff --git a/transport/internet/xdrive/xdrive.go b/transport/internet/xdrive/xdrive.go new file mode 100644 index 000000000000..eacd47a75fd9 --- /dev/null +++ b/transport/internet/xdrive/xdrive.go @@ -0,0 +1,594 @@ +package xdrive + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/xtls/xray-core/common" + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/signal/done" + "github.com/xtls/xray-core/common/uuid" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/stat" +) + +const protocolName = "xdrive" + +// Configuration constants +const ( + // ReadPollInterval is the interval for polling when reading data + ReadPollInterval = 100 * time.Millisecond + // ServerPollInterval is the interval for polling new connections on server + ServerPollInterval = 500 * time.Millisecond + // ReadTimeout is the timeout for read operations + ReadTimeout = 10 * time.Second + // RetentionWindow is how long files are kept before cleanup + RetentionWindow = 10 * time.Second + // CleanupInterval is how often the cleanup routine runs + CleanupInterval = 5 * time.Second +) + +// DriveService defines the interface for remote storage services +type DriveService interface { + // Login authenticates with the service (no-op for local) + Login(ctx context.Context) error + // Upload creates a file with the given name and content + Upload(ctx context.Context, name string, data []byte) error + // List returns files matching the prefix, created within the given duration + List(ctx context.Context, prefix string, within time.Duration) ([]FileInfo, error) + // Download retrieves the content of a file + Download(ctx context.Context, name string) ([]byte, error) + // Delete removes a file + Delete(ctx context.Context, name string) error +} + +// FileInfo represents information about a file +type FileInfo struct { + Name string + CreatedAt time.Time +} + +// LocalDriveService implements DriveService using the local filesystem +type LocalDriveService struct { + remoteFolder string + mu sync.RWMutex +} + +// NewLocalDriveService creates a new LocalDriveService +func NewLocalDriveService(remoteFolder string) *LocalDriveService { + return &LocalDriveService{ + remoteFolder: remoteFolder, + } +} + +// Login is a no-op for local filesystem +func (s *LocalDriveService) Login(ctx context.Context) error { + // Ensure the folder exists + return os.MkdirAll(s.remoteFolder, 0755) +} + +// Upload creates a file with the given name and content +func (s *LocalDriveService) Upload(ctx context.Context, name string, data []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + + filePath := filepath.Join(s.remoteFolder, name) + return os.WriteFile(filePath, data, 0644) +} + +// List returns files matching the prefix, created within the given duration +func (s *LocalDriveService) List(ctx context.Context, prefix string, within time.Duration) ([]FileInfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entries, err := os.ReadDir(s.remoteFolder) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + cutoff := time.Now().Add(-within) + var files []FileInfo + + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasPrefix(name, prefix) { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if info.ModTime().Before(cutoff) { + continue + } + files = append(files, FileInfo{ + Name: name, + CreatedAt: info.ModTime(), + }) + } + + return files, nil +} + +// Download retrieves the content of a file +func (s *LocalDriveService) Download(ctx context.Context, name string) ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + filePath := filepath.Join(s.remoteFolder, name) + return os.ReadFile(filePath) +} + +// Delete removes a file +func (s *LocalDriveService) Delete(ctx context.Context, name string) error { + s.mu.Lock() + defer s.mu.Unlock() + + filePath := filepath.Join(s.remoteFolder, name) + err := os.Remove(filePath) + if os.IsNotExist(err) { + return nil + } + return err +} + +// ParseFileName parses a filename in format: sessionID-direction-seq +// Returns sessionID, direction ("up" or "down"), sequence number +func ParseFileName(name string) (sessionID string, direction string, seq int64, ok bool) { + // Remove extension if any + name = strings.TrimSuffix(name, filepath.Ext(name)) + + parts := strings.Split(name, "-") + if len(parts) < 7 { // UUID has 5 parts + direction + seq = at least 7 + return "", "", 0, false + } + + // UUID is the first 5 parts joined by "-" + sessionID = strings.Join(parts[:5], "-") + + // Validate it's a UUID + if _, err := uuid.ParseString(sessionID); err != nil { + return "", "", 0, false + } + + direction = parts[5] + if direction != "up" && direction != "down" { + return "", "", 0, false + } + + seq, err := strconv.ParseInt(parts[6], 10, 64) + if err != nil { + return "", "", 0, false + } + + return sessionID, direction, seq, true +} + +// MakeFileName creates a filename in format: sessionID-direction-seq +func MakeFileName(sessionID, direction string, seq int64) string { + return fmt.Sprintf("%s-%s-%d", sessionID, direction, seq) +} + +// XdriveConnection represents a connection over the XDRIVE transport +type XdriveConnection struct { + ctx context.Context + cancel context.CancelFunc + service DriveService + sessionID string + isClient bool // true for client, false for server + readDone *done.Instance + writeDone *done.Instance + readBuf []byte + readMu sync.Mutex + writeMu sync.Mutex + readSeq int64 + writeSeq int64 + localAddr net.Addr + remoteAddr net.Addr +} + +// newXdriveConnection creates a new XdriveConnection +func newXdriveConnection(ctx context.Context, service DriveService, sessionID string, isClient bool) *XdriveConnection { + ctx, cancel := context.WithCancel(ctx) + return &XdriveConnection{ + ctx: ctx, + cancel: cancel, + service: service, + sessionID: sessionID, + isClient: isClient, + readDone: done.New(), + writeDone: done.New(), + } +} + +// Read reads data from the connection by polling for files +func (c *XdriveConnection) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + // If we have buffered data, return it first + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil + } + + // Determine which direction to read from + // Client reads "down" (server to client), server reads "up" (client to server) + readDirection := "down" + if !c.isClient { + readDirection = "up" + } + + prefix := fmt.Sprintf("%s-%s-", c.sessionID, readDirection) + + // Poll for the expected file with timeout + deadline := time.Now().Add(ReadTimeout) + + for time.Now().Before(deadline) { + select { + case <-c.ctx.Done(): + return 0, io.EOF + case <-c.readDone.Wait(): + return 0, io.EOF + default: + } + + // List files within retention window + files, err := c.service.List(c.ctx, prefix, RetentionWindow) + if err != nil { + return 0, err + } + + // Sort files by sequence number + sort.Slice(files, func(i, j int) bool { + _, _, seqI, _ := ParseFileName(files[i].Name) + _, _, seqJ, _ := ParseFileName(files[j].Name) + return seqI < seqJ + }) + + // Look for the expected sequence file + for _, file := range files { + _, _, seq, ok := ParseFileName(file.Name) + if !ok { + continue + } + + if seq == c.readSeq { + // Found the expected file, download it + data, err := c.service.Download(c.ctx, file.Name) + if err != nil { + if os.IsNotExist(err) { + continue + } + return 0, err + } + + // Delete the file after reading + c.service.Delete(c.ctx, file.Name) + + c.readSeq++ + + // Copy to buffer + n := copy(b, data) + if n < len(data) { + c.readBuf = data[n:] + } + return n, nil + } + } + + time.Sleep(ReadPollInterval) + } + + return 0, errors.New("read timeout") +} + +// Write writes data to the connection by creating files +func (c *XdriveConnection) Write(b []byte) (int, error) { + c.writeMu.Lock() + defer c.writeMu.Unlock() + + select { + case <-c.ctx.Done(): + return 0, io.ErrClosedPipe + case <-c.writeDone.Wait(): + return 0, io.ErrClosedPipe + default: + } + + // Determine which direction to write + // Client writes "up" (client to server), server writes "down" (server to client) + writeDirection := "up" + if !c.isClient { + writeDirection = "down" + } + + fileName := MakeFileName(c.sessionID, writeDirection, c.writeSeq) + + err := c.service.Upload(c.ctx, fileName, b) + if err != nil { + return 0, err + } + + c.writeSeq++ + return len(b), nil +} + +// Close closes the connection +func (c *XdriveConnection) Close() error { + c.readDone.Close() + c.writeDone.Close() + c.cancel() + return nil +} + +// LocalAddr returns the local address +func (c *XdriveConnection) LocalAddr() net.Addr { + if c.localAddr != nil { + return c.localAddr + } + return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 0} +} + +// RemoteAddr returns the remote address +func (c *XdriveConnection) RemoteAddr() net.Addr { + if c.remoteAddr != nil { + return c.remoteAddr + } + return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 0} +} + +// SetDeadline sets the deadline for the connection +func (c *XdriveConnection) SetDeadline(t time.Time) error { + return nil +} + +// SetReadDeadline sets the read deadline for the connection +func (c *XdriveConnection) SetReadDeadline(t time.Time) error { + return nil +} + +// SetWriteDeadline sets the write deadline for the connection +func (c *XdriveConnection) SetWriteDeadline(t time.Time) error { + return nil +} + +func init() { + common.Must(internet.RegisterProtocolConfigCreator(protocolName, func() interface{} { + return new(Config) + })) + common.Must(internet.RegisterTransportDialer(protocolName, Dial)) + common.Must(internet.RegisterTransportListener(protocolName, Serve)) +} + +// createDriveService creates a DriveService based on the configuration +func createDriveService(config *Config) (DriveService, error) { + switch config.Service { + case "local": + return NewLocalDriveService(config.RemoteFolder), nil + case "Google Drive": + // Placeholder for Google Drive implementation + return nil, errors.New("Google Drive service not yet implemented") + default: + return nil, errors.New("unsupported service: " + config.Service) + } +} + +// Dial creates a client connection to the XDRIVE transport +func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { + config := streamSettings.ProtocolSettings.(*Config) + + // Validate secrets - for client, secrets[0] should be "client" + if len(config.Secrets) == 0 || config.Secrets[0] != "client" { + return nil, errors.New("client must have secrets[0] set to 'client'") + } + + service, err := createDriveService(config) + if err != nil { + return nil, err + } + + if err := service.Login(ctx); err != nil { + return nil, errors.New("failed to login to drive service").Base(err) + } + + // Generate a new session ID for this connection + newUUID := uuid.New() + sessionID := newUUID.String() + + errors.LogInfo(ctx, fmt.Sprintf("XDRIVE client dialing with session %s to folder %s", sessionID, config.RemoteFolder)) + + conn := newXdriveConnection(ctx, service, sessionID, true) + return stat.Connection(conn), nil +} + +// Server represents an XDRIVE server listener +type Server struct { + ctx context.Context + cancel context.CancelFunc + config *Config + service DriveService + addConn internet.ConnHandler + sessions sync.Map // sessionID -> *XdriveConnection + closeDone *done.Instance +} + +// Close closes the server +func (s *Server) Close() error { + s.closeDone.Close() + s.cancel() + return nil +} + +// Addr returns the address the server is listening on +func (s *Server) Addr() net.Addr { + return &net.TCPAddr{IP: net.IP{0, 0, 0, 0}, Port: 0} +} + +// Serve creates a server listener for the XDRIVE transport +func Serve(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) { + config := streamSettings.ProtocolSettings.(*Config) + + // Validate secrets - for server, secrets[0] should be "server" + if len(config.Secrets) == 0 || config.Secrets[0] != "server" { + return nil, errors.New("server must have secrets[0] set to 'server'") + } + + service, err := createDriveService(config) + if err != nil { + return nil, err + } + + if err := service.Login(ctx); err != nil { + return nil, errors.New("failed to login to drive service").Base(err) + } + + ctx, cancel := context.WithCancel(ctx) + + server := &Server{ + ctx: ctx, + cancel: cancel, + config: config, + service: service, + addConn: addConn, + closeDone: done.New(), + } + + errors.LogInfo(ctx, fmt.Sprintf("XDRIVE server listening on folder %s", config.RemoteFolder)) + + // Start polling for new connections + go server.pollForConnections() + + // Start cleanup routine for old files + go server.cleanupOldFiles() + + return server, nil +} + +// pollForConnections polls the folder for new upload files indicating new connections +func (s *Server) pollForConnections() { + ticker := time.NewTicker(ServerPollInterval) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-s.closeDone.Wait(): + return + case <-ticker.C: + s.checkForNewConnections() + } + } +} + +// checkForNewConnections checks for files from new client sessions +func (s *Server) checkForNewConnections() { + // List all files within the retention window + files, err := s.service.List(s.ctx, "", RetentionWindow) + if err != nil { + errors.LogWarning(s.ctx, "failed to list files: ", err) + return + } + + // Track seen session IDs + seenSessions := make(map[string]bool) + + for _, file := range files { + sessionID, direction, _, ok := ParseFileName(file.Name) + if !ok { + continue + } + + // Only process "up" files (client to server) + if direction != "up" { + continue + } + + // Check if we already have a session for this ID + if seenSessions[sessionID] { + continue + } + seenSessions[sessionID] = true + + if _, loaded := s.sessions.Load(sessionID); loaded { + continue + } + + // New session detected, create a connection + conn := newXdriveConnection(s.ctx, s.service, sessionID, false) + s.sessions.Store(sessionID, conn) + + errors.LogInfo(s.ctx, fmt.Sprintf("XDRIVE server accepted new connection: %s", sessionID)) + + // Handle the connection + s.addConn(stat.Connection(conn)) + } +} + +// cleanupOldFiles removes files older than the retention window +func (s *Server) cleanupOldFiles() { + ticker := time.NewTicker(CleanupInterval) + defer ticker.Stop() + + for { + select { + case <-s.ctx.Done(): + return + case <-s.closeDone.Wait(): + return + case <-ticker.C: + s.doCleanup(RetentionWindow) + } + } +} + +// doCleanup removes files older than the retention window +func (s *Server) doCleanup(retentionWindow time.Duration) { + localService, ok := s.service.(*LocalDriveService) + if !ok { + return + } + + localService.mu.Lock() + defer localService.mu.Unlock() + + entries, err := os.ReadDir(localService.remoteFolder) + if err != nil { + return + } + + cutoff := time.Now().Add(-retentionWindow) + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + info, err := entry.Info() + if err != nil { + continue + } + + if info.ModTime().Before(cutoff) { + filePath := filepath.Join(localService.remoteFolder, entry.Name()) + os.Remove(filePath) + } + } +} diff --git a/transport/internet/xdrive/xdrive_test.go b/transport/internet/xdrive/xdrive_test.go new file mode 100644 index 000000000000..161071b34576 --- /dev/null +++ b/transport/internet/xdrive/xdrive_test.go @@ -0,0 +1,109 @@ +package xdrive_test + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + . "github.com/xtls/xray-core/transport/internet/xdrive" +) + +func TestLocalDriveService(t *testing.T) { + // Create a temporary directory for testing + tmpDir := filepath.Join(os.TempDir(), "xdrive_test") + defer os.RemoveAll(tmpDir) + + service := NewLocalDriveService(tmpDir) + ctx := context.Background() + + // Test Login (creates directory) + if err := service.Login(ctx); err != nil { + t.Fatalf("Login failed: %v", err) + } + + // Test Upload + testData := []byte("hello world") + testFileName := "test-file-1" + if err := service.Upload(ctx, testFileName, testData); err != nil { + t.Fatalf("Upload failed: %v", err) + } + + // Test List + files, err := service.List(ctx, "test-", 10*time.Second) + if err != nil { + t.Fatalf("List failed: %v", err) + } + if len(files) != 1 { + t.Fatalf("Expected 1 file, got %d", len(files)) + } + if files[0].Name != testFileName { + t.Fatalf("Expected file name %s, got %s", testFileName, files[0].Name) + } + + // Test Download + downloaded, err := service.Download(ctx, testFileName) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + if string(downloaded) != string(testData) { + t.Fatalf("Downloaded data mismatch: expected %s, got %s", string(testData), string(downloaded)) + } + + // Test Delete + if err := service.Delete(ctx, testFileName); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Verify file is deleted + files, err = service.List(ctx, "test-", 10*time.Second) + if err != nil { + t.Fatalf("List after delete failed: %v", err) + } + if len(files) != 0 { + t.Fatalf("Expected 0 files after delete, got %d", len(files)) + } +} + +func TestFileNameParsing(t *testing.T) { + // Test valid filename + sessionID := "550e8400-e29b-41d4-a716-446655440000" + fileName := sessionID + "-up-5" + + parsed, direction, seq, ok := ParseFileName(fileName) + if !ok { + t.Fatalf("Failed to parse valid filename") + } + if parsed != sessionID { + t.Fatalf("Session ID mismatch: expected %s, got %s", sessionID, parsed) + } + if direction != "up" { + t.Fatalf("Direction mismatch: expected up, got %s", direction) + } + if seq != 5 { + t.Fatalf("Seq mismatch: expected 5, got %d", seq) + } + + // Test invalid filenames + invalidNames := []string{ + "invalid-file", + "550e8400-e29b-41d4-a716-446655440000-invalid-5", + "not-a-uuid-up-5", + } + for _, name := range invalidNames { + _, _, _, ok := ParseFileName(name) + if ok { + t.Fatalf("Expected parsing to fail for %s", name) + } + } +} + +func TestMakeFileName(t *testing.T) { + sessionID := "550e8400-e29b-41d4-a716-446655440000" + expected := "550e8400-e29b-41d4-a716-446655440000-down-10" + result := MakeFileName(sessionID, "down", 10) + if result != expected { + t.Fatalf("MakeFileName mismatch: expected %s, got %s", expected, result) + } +}