Skip to content

Commit

Permalink
Simplify clients (#4)
Browse files Browse the repository at this point in the history
* removed tmp_path tests

* added _BaseClient, refactored classes, tests

* fixes to Client class

* simplified conftest

* added tests for _BaseClient

* renamed __create_client to __connect_to_server
  • Loading branch information
charlottekostelic committed Aug 6, 2024
1 parent 240720e commit 857a6f8
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 223 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -r dev-requirements.txt
- name: Run tests
run: pytest -m "not livetest and not tmpdir" --cov=file_retriever/
run: pytest -m "not livetest" --cov=file_retriever/
- name: Send report to Coveralls
uses: AndreMiras/coveralls-python-action@develop
with:
Expand Down
67 changes: 59 additions & 8 deletions file_retriever/_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from abc import ABC, abstractmethod
import ftplib
import os
import paramiko
Expand All @@ -14,7 +15,53 @@
from file_retriever.file import File


class _ftpClient:
class _BaseClient(ABC):
""""""

@abstractmethod
def __init__(self, username: str, password: str, host: str, port: Union[str, int]):
self.connection: Union[ftplib.FTP, paramiko.SFTPClient] = (
self._connect_to_server(
username=username, password=password, host=host, port=int(port)
)
)

@abstractmethod
def __enter__(self, *args):
return self

@abstractmethod
def __exit__(self, *args):
self.connection.close()

@abstractmethod
def _connect_to_server(
self,
username: str,
password: str,
host: str,
port: int,
) -> Union[ftplib.FTP, paramiko.SFTPClient]:
pass

@abstractmethod
def get_remote_file_data(self, file: str, remote_dir: str) -> File:
pass

@abstractmethod
def list_remote_file_data(self, remote_dir: str) -> List[File]:
pass

@abstractmethod
def download_file(self, file: str, remote_dir: str, local_dir: str) -> None:
pass

@abstractmethod
def upload_file(self, file: str, remote_dir: str, local_dir: str) -> File:
pass


class _ftpClient(_BaseClient):
"""
An FTP client to use when interacting with remote storage. Supports
interactions with servers via the `ftplib` library.
Expand All @@ -36,7 +83,9 @@ def __init__(
port: port number for server
"""
self.connection = self._create_ftp_connection(
if port not in [21, "21"]:
raise ValueError("Invalid port number for FTP connection.")
self.connection: ftplib.FTP = self._connect_to_server(
username=username, password=password, host=host, port=int(port)
)

Expand All @@ -56,7 +105,7 @@ def __exit__(self, *args):
"""
self.connection.close()

def _create_ftp_connection(
def _connect_to_server(
self, username: str, password: str, host: str, port: int
) -> ftplib.FTP:
"""
Expand Down Expand Up @@ -135,7 +184,7 @@ def get_file_permissions(data):
f"Unable to retrieve file data: {sys.exc_info()[1]}"
)

def list_file_data(self, remote_dir: str) -> List[File]:
def list_remote_file_data(self, remote_dir: str) -> List[File]:
"""
Retrieves metadata for each file in `remote_dir` on server.
Expand Down Expand Up @@ -217,7 +266,7 @@ def upload_file(self, file: str, remote_dir: str, local_dir: str) -> File:
raise


class _sftpClient:
class _sftpClient(_BaseClient):
"""
An SFTP client to use when interacting with remote storage. Supports
interactions with servers via the `paramiko` library.
Expand All @@ -239,7 +288,9 @@ def __init__(
port: port number for server
"""
self.connection = self._create_sftp_connection(
if port not in [22, "22"]:
raise ValueError("Invalid port number for SFTP connection.")
self.connection: paramiko.SFTPClient = self._connect_to_server(
username=username, password=password, host=host, port=int(port)
)

Expand All @@ -259,7 +310,7 @@ def __exit__(self, *args):
"""
self.connection.close()

def _create_sftp_connection(
def _connect_to_server(
self, username: str, password: str, host: str, port: int
) -> paramiko.SFTPClient:
"""
Expand Down Expand Up @@ -311,7 +362,7 @@ def get_remote_file_data(self, file: str, remote_dir: str) -> File:
except OSError:
raise

def list_file_data(self, remote_dir: str) -> List[File]:
def list_remote_file_data(self, remote_dir: str) -> List[File]:
"""
Lists metadata for each file in `remote_dir` on server.
Expand Down
22 changes: 11 additions & 11 deletions file_retriever/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@ def __init__(

self.vendor = vendor
self.host = host
self.port = int(port)
self.port = port
self.remote_dir = remote_dir

self.session = self._create_client(username=username, password=password)
self.session = self.__connect_to_server(username=username, password=password)

def _create_client(
def __connect_to_server(
self, username: str, password: str
) -> Union[_ftpClient, _sftpClient]:
match self.port:
case 21:
case 21 | "21":
return _ftpClient(
username=username,
password=password,
host=self.host,
port=self.port,
)
case 22:
case 22 | "22":
return _sftpClient(
username=username,
password=password,
Expand Down Expand Up @@ -84,24 +84,24 @@ def check_file(self, file: str, check_dir: str, remote: bool) -> bool:
else:
return os.path.exists(os.path.join(check_dir, file))

def get_file_data(self, file: str, remote_dir: Optional[str] = None) -> List[File]:
def get_file_data(self, file: str, remote_dir: Optional[str] = None) -> File:
"""
Retrieve metadata for file in `remote_dir` on server. If `remote_dir` is not
provided then data for file in `self.remote_dir` will be retrieved.
Retrieve metadata for `file` in `remote_dir` on server. If `remote_dir` is not
provided then data for `file` in `self.remote_dir` will be retrieved.
Args:
file: name of file to retrieve metadata for
remote_dir: directory on server to interact with
Returns:
files in `remote_dir` represented as `File` object
file in `remote_dir` represented as `File` object
"""
if not remote_dir or remote_dir is None:
remote_dir = self.remote_dir
with self.session as session:
return session.get_remote_file_data(file, remote_dir)

def list_files(
def list_files_in_dir(
self, time_delta: int = 0, remote_dir: Optional[str] = None
) -> List[File]:
"""
Expand All @@ -122,7 +122,7 @@ def list_files(
if not remote_dir or remote_dir is None:
remote_dir = self.remote_dir
with self.session as session:
files = session.list_file_data(remote_dir)
files = session.list_remote_file_data(remote_dir)
if time_delta > 0:
return [
i
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pytest-mock = "^3.14.0"
testpaths = ["tests"]
markers = [
"livetest: mark a test as hitting a live ftp/sftp server",
"tmpdir: mark a test as using a temporary directory",
]


Expand Down
Loading

0 comments on commit 857a6f8

Please sign in to comment.