Skip to content

Commit

Permalink
add sftpfs.EnsureRegistered
Browse files Browse the repository at this point in the history
  • Loading branch information
ungerik committed Dec 21, 2023
1 parent 882d07f commit ac1b5b7
Showing 1 changed file with 54 additions and 18 deletions.
72 changes: 54 additions & 18 deletions sftpfs/sftpfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,48 +76,57 @@ type fileSystem struct {
// If no port is provided in the address, then port 22 will be used.
// The address can contain a username or a username and password.
func Dial(ctx context.Context, address string, loginCallback LoginCallback, hostKeyCallback ssh.HostKeyCallback) (fs.FileSystem, error) {
u, username, password, prefix, err := prepareDial(address, loginCallback, hostKeyCallback)
if err != nil {
return nil, err
}
client, err := dial(ctx, u.Host, username, password, hostKeyCallback)
if err != nil {
return nil, err
}
return &fileSystem{
client: client,
prefix: prefix,
}, nil
}

func prepareDial(address string, loginCallback LoginCallback, hostKeyCallback ssh.HostKeyCallback) (u *url.URL, username, password, prefix string, err error) {
if !strings.HasPrefix(address, "sftp://") {
if strings.Contains(address, "://") {
return nil, fmt.Errorf("URL must start with sftp:// but got %s", address)
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
address = "sftp://" + address
}
if loginCallback == nil {
return nil, fmt.Errorf("missing loginCallback")
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
if hostKeyCallback == nil {
return nil, fmt.Errorf("missing hostKeyCallback")
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
u, err := url.Parse(address)
u, err = url.Parse(address)
if err != nil {
return nil, err
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
if u.Scheme != "sftp" {
return nil, fmt.Errorf("URL scheme must be sftp:// but got %s://", u.Scheme)
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
if u.Port() == "" {
u.Host += ":22"
}

username, password, err := loginCallback(u)
username, password, err = loginCallback(u)
if err != nil {
return nil, err
return nil, "", "", "", err
}
if username == "" {
return nil, fmt.Errorf("missing SFTP username for: %s", address)
return nil, "", "", "", fmt.Errorf("missing SFTP username for: %s", address)
}
if password == "" {
return nil, fmt.Errorf("missing SFTP password for: %s", address)
}
client, err := dial(ctx, u.Host, username, password, hostKeyCallback)
if err != nil {
return nil, err
return nil, "", "", "", fmt.Errorf("missing SFTP password for: %s", address)
}
return &fileSystem{
client: client,
prefix: fmt.Sprintf("sftp://%s@%s", url.User(username), u.Host),
}, nil
prefix = fmt.Sprintf("sftp://%s@%s", url.User(username), u.Host)

return u, username, password, prefix, nil
}

// DialAndRegister dials a new SFTP connection and register it as file system.
Expand All @@ -134,6 +143,33 @@ func DialAndRegister(ctx context.Context, address string, loginCallback LoginCal
return fileSystem, nil
}

// EnsureRegistered first checks if a SFTP file system with the passed address
// is already registered. If not, then a new connection is dialed and registered.
// The returned free function has to be called to decrease the file system's
// reference count and close it when the reference count reaches 0.
func EnsureRegistered(ctx context.Context, address string, loginCallback LoginCallback, hostKeyCallback ssh.HostKeyCallback) (free func(), err error) {
u, username, password, prefix, err := prepareDial(address, loginCallback, hostKeyCallback)
if err != nil {
return nil, err
}
f := fs.GetFileSystemByPrefixOrNil(prefix)
if f != nil {
fs.Register(f) // Increase ref count
return func() { fs.Unregister(f) }, nil
}

client, err := dial(ctx, u.Host, username, password, hostKeyCallback)
if err != nil {
return nil, err
}
f = &fileSystem{
client: client,
prefix: prefix,
}
fs.Register(f)
return func() { fs.Unregister(f) }, nil
}

func dial(ctx context.Context, host, user, password string, hostKeyCallback ssh.HostKeyCallback) (*sftp.Client, error) {
config := &ssh.ClientConfig{
User: user,
Expand Down

0 comments on commit ac1b5b7

Please sign in to comment.