diff --git a/file_retriever/_clients.py b/file_retriever/_clients.py index 9e68ff1..b963d09 100644 --- a/file_retriever/_clients.py +++ b/file_retriever/_clients.py @@ -1,24 +1,29 @@ -"""This module contains classes for interacting with remote storage via ftp and -sftp clients. - -Can be used to connect to vendor servers or internal network drives. +"""This module contains protected classes for interacting with remote storage +via ftp and sftp clients. +Can be used within `Client` class to connect to vendor servers or internal +network drives. """ from abc import ABC, abstractmethod import ftplib +import io import logging import os import paramiko -import sys -from typing import List, Union, Optional -from file_retriever.file import File +from typing import List, Union +from file_retriever.file import FileInfo, File +from file_retriever.errors import ( + RetrieverFileError, + RetrieverConnectionError, + RetrieverAuthenticationError, +) logger = logging.getLogger("file_retriever") class _BaseClient(ABC): - """""" + """An abstract base class for FTP and SFTP clients.""" @abstractmethod def __init__(self, username: str, password: str, host: str, port: Union[str, int]): @@ -39,34 +44,38 @@ def _connect_to_server( pass @abstractmethod - def close(self, *args) -> None: + def _check_dir(self, dir: str) -> None: pass @abstractmethod - def download_file(self, file: str, remote_dir: str, local_dir: str) -> None: + def close(self) -> None: pass @abstractmethod - def get_remote_file_data(self, file: str, remote_dir: str) -> File: + def fetch_file(self, file: FileInfo, dir: str) -> File: pass @abstractmethod - def is_active(self) -> bool: + def get_file_data(self, file_name: str, dir: str) -> FileInfo: + pass + + @abstractmethod + def list_file_data(self, dir: str) -> List[FileInfo]: pass @abstractmethod - def list_remote_file_data(self, remote_dir: str) -> List[File]: + def is_active(self) -> bool: pass @abstractmethod - def upload_file(self, file: str, remote_dir: str, 
local_dir: str) -> File: + def write_file(self, file: File, dir: str, remote: bool) -> FileInfo: pass class _ftpClient(_BaseClient): """ An FTP client to use when interacting with remote storage. Supports - interactions with servers via the `ftplib` library. + interactions with servers using an `ftplib.FTP` object. """ def __init__( @@ -103,7 +112,6 @@ def _connect_to_server( ftplib.error_temp: if unable to connect to server ftplib.error_perm: if unable to authenticate with server """ - logger.debug(f"Connecting to {host} via FTP client") try: ftp_client = ftplib.FTP() ftp_client.connect(host=host, port=port) @@ -112,108 +120,126 @@ def _connect_to_server( user=username, passwd=password, ) - logger.debug(f"Connected at {port} to {host}") return ftp_client - except ftplib.error_perm: - logger.error( - f"Unable to authenticate with provided credentials: {sys.exc_info()[1]}" - ) - raise - except ftplib.error_temp: - logger.error(f"Unable to connect to {host}: {sys.exc_info()[1]}") - raise + except ftplib.error_perm as e: + logger.error(f"Unable to authenticate with provided credentials: {e}") + raise RetrieverAuthenticationError + except ftplib.error_temp as e: + logger.error(f"Unable to connect to {host}: {e}") + raise RetrieverConnectionError + + def _check_dir(self, dir: str) -> None: + """Changes directory to `dir` if not already in `dir`.""" + if self.connection.pwd().lstrip("/") != dir.lstrip("/"): + self.connection.cwd(dir) + else: + pass - def close(self): + def close(self) -> None: """Closes connection to server.""" self.connection.close() - def download_file(self, file: str, remote_dir: str, local_dir: str) -> None: + def fetch_file(self, file: FileInfo, dir: str) -> File: """ - Downloads file from `remote_dir` on server to `local_dir`. + Retrieves file from `dir` on server as `File` object. 
The returned + `File` object contains the file's content as an `io.BytesIO` object + in the `File.file_stream` attribute and the file's metadata in the other + attributes. Args: file: - name of file to upload - remote_dir: - remote directory to download file from - local_dir: - local directory to download file to + `FileInfo` object representing metadata for file to fetch. + file is fetched based on `file_name` attribute. + dir: + directory on server to fetch file from Returns: - None + `File` object representing content and metadata of fetched file Raises: - OSError: if unable to download file from server or if file is not found + ftplib.error_perm: if unable to retrieve file from server + """ - local_file = os.path.normpath(os.path.join(local_dir, file)) - remote_file = os.path.normpath(os.path.join(remote_dir, file)) try: - logger.debug(f"Downloading {remote_file} to {local_file} via FTP client") - with open(local_file, "wb") as f: - self.connection.retrbinary(f"RETR {remote_file}", f.write) - except OSError: - logger.error( - f"Unable to download {remote_file} to {local_file}: {sys.exc_info()[1]}" - ) - raise - except ftplib.error_reply: - logger.error( - f"Received unexpected response from server: {sys.exc_info()[1]}" - ) - raise - - def get_remote_file_data(self, file: str, remote_dir: Optional[str] = None) -> File: + self._check_dir(dir) + fh = io.BytesIO() + self.connection.retrbinary(f"RETR {file.file_name}", fh.write) + fetched_file = File.from_fileinfo(file=file, file_stream=fh) + return fetched_file + except ftplib.error_perm as e: + logger.error(f"Unable to retrieve {file} from {dir}: {e}") + raise RetrieverFileError + + def get_file_data(self, file_name: str, dir: str) -> FileInfo: """ - Retrieves metadata for single file on server. + Retrieves metadata for file on server. Requires multiple + calls to server to retrieve file size, modification time, + and permissions. 
Args: - file: name of file to retrieve metadata for - remote_dir: directory on server to interact with + file_name: name of file to retrieve metadata for + dir: directory on server to interact with Returns: - `File` object representing file in `remote_dir` + `FileInfo` object representing metadata for file in `dir` Raises: - ftplib.error_reply: - if `file` or `remote_dir` does not exist or if server response - code is not in range 200-299 + ftplib.error_perm: + if unable to retrieve file data due to permissions error """ - if remote_dir is not None: - remote_file = f"{remote_dir}/{file}" - else: - remote_file = file - file_name = os.path.basename(remote_file) try: + self._check_dir(dir) + # Retrieve file permissions permissions = None def get_file_permissions(data): nonlocal permissions - permissions = File.parse_permissions(permissions_str=data) - - self.connection.retrlines(f"LIST {remote_file}", get_file_permissions), - if permissions is None: - logger.error(f"{file} not found on server.") - raise ftplib.error_perm("File not found on server.") - logger.debug(f"Retrieving file data for {remote_file}") - return File( + permissions = data[0:10] + + self.connection.retrlines(f"LIST {file_name}", get_file_permissions), + size = self.connection.size(file_name) + time = self.connection.voidcmd(f"MDTM {file_name}") + + if permissions is None or size is None or time is None: + logger.error(f"Unable to retrieve file data for {file_name}.") + raise RetrieverFileError + + return FileInfo( file_name=file_name, - file_size=self.connection.size(remote_file), - file_mtime=File.parse_mdtm_time( - self.connection.voidcmd(f"MDTM {remote_file}") - ), + file_size=size, + file_mtime=time[4:], file_mode=permissions, ) - except ftplib.error_reply: - logger.error( - f"Received unexpected response from server: {sys.exc_info()[1]}" - ) - raise except ftplib.error_perm: - logger.error( - f"Unable to retrieve file data for {file}: {sys.exc_info()[1]}" - ) - raise + raise RetrieverFileError 
+ + def list_file_data(self, dir: str) -> List[FileInfo]: + """ + Retrieves metadata for each file in `dir` on server. + + Args: + dir: directory on server to interact with + + Returns: + list of `FileInfo` objects representing files in `dir` + returns an empty list if `dir` is empty or does not exist + + Raises: + ftplib.error_perm: + if unable to list file data due to permissions error + + """ + files = [] + try: + file_names = self.connection.nlst(dir) + for name in file_names: + file_base_name = os.path.basename(name) + file_info = self.get_file_data(file_name=file_base_name, dir=dir) + files.append(file_info) + except ftplib.error_perm: + raise RetrieverFileError + return files def is_active(self) -> bool: """ @@ -228,79 +254,62 @@ def is_active(self) -> bool: else: return False - def list_remote_file_data(self, remote_dir: str) -> List[File]: + def write_file(self, file: File, dir: str, remote: bool) -> FileInfo: """ - Retrieves metadata for each file in `remote_dir` on server. - - Args: - remote_dir: directory on server to interact with - - Returns: - list of `File` objects representing files in `remote_dir` - - Raises: - ftplib.error_reply: - if `remote_dir` does not exist or if server response code is - not in range 200-299 - """ - files = [] - try: - file_data_list = self.connection.nlst(remote_dir) - for data in file_data_list: - file = self.get_remote_file_data(data) - files.append(file) - except ftplib.error_reply: - logger.error( - f"Unable to retrieve file data for {remote_dir}: {sys.exc_info()[1]}" - ) - raise - logger.debug(f"Retrieved file data for {len(files)} files in {remote_dir}") - return files - - def upload_file(self, file: str, remote_dir: str, local_dir: str) -> File: - """ - Upload file from `local_dir` to `remote_dir` on server. + Writes file to directory. If `remote` is True, then file is written + to `dir` on server. If `remote` is False, then file is written to local + directory. 
Retrieves metadata for file after it has been written + and returns metadata as `FileInfo`. Args: file: - name of file to upload - remote_dir: - remote directory to upload file to - local_dir: - local directory to upload file from + `File` object representing file to write. content of file to + be written is stored in `File.file_stream` attribute. + dir: + directory to write file to + remote: + bool indicating if file should be written to remote or local + directory Returns: - uploaded file as `File` object + `FileInfo` object representing written file Raises: - OSError: - if unable to upload file to remote directory or if file is not found. - ftplib.error_reply: - if server response code is not in range 200-299 + ftplib.error_perm: if unable to write file to remote directory + OSError: if unable to write file to local directory """ - local_file = os.path.normpath(os.path.join(local_dir, file)) - remote_file = os.path.normpath(os.path.join(remote_dir, file)) - try: - logger.debug(f"Uploading {local_file} to {remote_dir} via FTP client") - with open(remote_file, "rb") as rf: - self.connection.storbinary(f"STOR {local_file}", rf) - return self.get_remote_file_data(file, remote_dir) - except OSError: - logger.error( - f"Unable to upload {local_file} to {remote_dir}: {sys.exc_info()[1]}" - ) - raise - except ftplib.error_reply: - logger.error( - f"Received unexpected response from server: {sys.exc_info()[1]}" - ) - raise + # make sure file stream is at beginning before writing + file.file_stream.seek(0) + + if remote is True: + try: + self._check_dir(dir) + self.connection.storbinary(f"STOR {file.file_name}", file.file_stream) + return self.get_file_data(file_name=file.file_name, dir=dir) + except ftplib.error_perm as e: + logger.error( + f"Unable to write {file.file_name} to remote directory: {e}" + ) + raise RetrieverFileError + else: + try: + local_file = f"{dir}/{file.file_name}" + with open(local_file, "wb") as lf: + lf.write(file.file_stream.getbuffer()) + return 
FileInfo.from_stat_data( + data=os.stat(local_file), file_name=file.file_name + ) + except OSError as e: + logger.error( + f"Unable to write {file.file_name} to local directory: {e}" + ) + raise RetrieverFileError class _sftpClient(_BaseClient): """ An SFTP client to use when interacting with remote storage. Supports - interactions with servers via the `paramiko` library. + interactions with servers using a `paramiko.SFTPClient` object. """ def __init__( @@ -338,7 +347,6 @@ def _connect_to_server( paramiko.AuthenticationException: if unable to authenticate with server """ - logger.debug(f"Connecting to {host} via SFTP client") try: ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) @@ -349,77 +357,102 @@ def _connect_to_server( password=password, ) sftp_client = ssh.open_sftp() - logger.debug(f"Connected at {port} to {host}") return sftp_client - except paramiko.AuthenticationException: - logger.error( - f"Unable to authenticate with provided credentials: {sys.exc_info()[1]}" - ) - raise - except paramiko.SSHException: - logger.error(f"Unable to connect to {host}: {sys.exc_info()[1]}") - raise + except paramiko.AuthenticationException as e: + logger.error(f"Unable to authenticate with provided credentials: {e}") + raise RetrieverAuthenticationError + except paramiko.SSHException as e: + logger.error(f"Unable to connect to {host}: {e}") + raise RetrieverConnectionError + + def _check_dir(self, dir: str) -> None: + """Changes directory to `dir` if not already in `dir`.""" + wd = self.connection.getcwd() + if wd is None: + self.connection.chdir(dir) + elif isinstance(wd, str) and wd.lstrip("/") != dir.lstrip("/"): + self.connection.chdir(None) + self.connection.chdir(dir) + else: + pass def close(self): """Closes connection to server.""" self.connection.close() - def download_file(self, file: str, remote_dir: str, local_dir: str) -> None: + def fetch_file(self, file: FileInfo, dir: str) -> File: """ - Downloads file from `remote_dir` 
on server to `local_dir`. + Retrieves file from `dir` on server as `File` object. The returned + `File` object contains the file's content as an `io.BytesIO` object + in the `File.file_stream` attribute and the file's metadata in the other + attributes. Args: file: - name of file to upload - remote_dir: - remote directory to download file from - local_dir: - local directory to download file to + `FileInfo` object representing metadata for file to fetch. + file is fetched based on `file_name` attribute. + dir: + directory on server to fetch file from Returns: - None + `File` object representing content and metadata of fetched file Raises: - OSError: if unable to download file from server or if file is not found + OSError: if unable to retrieve file from server + """ - local_file = os.path.normpath(os.path.join(local_dir, file)) - remote_file = os.path.normpath(os.path.join(remote_dir, file)) try: - logger.debug(f"Downloading {remote_file} to {local_file} via SFTP client") - self.connection.get(remote_file, local_file) - except OSError: - logger.error( - f"Unable to download {remote_file} to {local_file}: {sys.exc_info()[1]}" - ) - raise - - def get_remote_file_data(self, file: str, remote_dir: str) -> File: + self._check_dir(dir) + fh = io.BytesIO() + self.connection.getfo(remotepath=file.file_name, fl=fh) + fetched_file = File.from_fileinfo(file=file, file_stream=fh) + return fetched_file + except OSError as e: + logger.error(f"Unable to retrieve {file.file_name} from {dir}: {e}") + raise RetrieverFileError + + def get_file_data(self, file_name: str, dir: str) -> FileInfo: """ - Retrieves metadata for single file on server. + Retrieves metadata for file on server. 
Args: - file: name of file to retrieve metadata for - remote_dir: directory on server to interact with + file_name: name of file to retrieve metadata for + dir: directory on server to interact with Returns: - `File` object representing file in `remote_dir` + `FileInfo` object representing metadata for file in `dir` Raises: - ftplib.error_reply: - if `file` or `remote_dir` does not exist or if server response - code is not in range 200-299 + OSError: if file or `dir` does not exist """ - remote_file = os.path.normpath(os.path.join(remote_dir, file)) try: - logger.debug(f"Retrieving file data for {remote_file}") - return File.from_stat_data( - data=self.connection.stat(remote_file), file_name=file + self._check_dir(dir) + return FileInfo.from_stat_data( + data=self.connection.stat(file_name), file_name=file_name ) except OSError: - logger.error( - f"Unable to retrieve file data for {file}: {sys.exc_info()[1]}" - ) - raise + raise RetrieverFileError + + def list_file_data(self, dir: str) -> List[FileInfo]: + """ + Retrieves metadata for each file in `dir` on server. + + Args: + dir: directory on server to interact with + + Returns: + list of `FileInfo` objects representing files in `dir` + returns an empty list if `dir` is empty or does not exist + + Raises: + OSError: if `dir` does not exist + """ + try: + file_metadata = self.connection.listdir_attr(dir) + return [FileInfo.from_stat_data(data=i) for i in file_metadata] + except OSError as e: + logger.error(f"Unable to retrieve file data for {dir}: {e}") + raise RetrieverFileError def is_active(self) -> bool: """ @@ -434,60 +467,51 @@ def is_active(self) -> bool: else: return True - def list_remote_file_data(self, remote_dir: str) -> List[File]: - """ - Lists metadata for each file in `remote_dir` on server. 
- - Args: - remote_dir: directory on server to interact with - - Returns: - list of `File` objects representing files in `remote_dir` - - Raises: - OSError: if `remote_dir` does not exist - """ - try: - file_metadata = self.connection.listdir_attr(remote_dir) - logger.debug( - f"Retrieved file data for {len(file_metadata)} files in {remote_dir}" - ) - return [File.from_stat_data(data=i) for i in file_metadata] - except OSError: - logger.error( - f"Unable to retrieve file data for {remote_dir}: {sys.exc_info()[1]}" - ) - raise - - def upload_file(self, file: str, remote_dir: str, local_dir: str) -> File: + def write_file(self, file: File, dir: str, remote: bool) -> FileInfo: """ - Upload file from `local_dir` to `remote_dir` on server. + Writes file to directory. If `remote` is True, then file is written + to `dir` on server. If `remote` is False, then file is written to local + directory. Retrieves metadata for file after it has been written + and returns metadata as `FileInfo`. Args: file: - name of file to upload - remote_dir: - remote directory to upload file to - local_dir: - local directory to upload file from + `File` object representing file to write. content of file to + be written is stored in `File.file_stream` attribute. + dir: + directory to write file to + remote: + bool indicating if file should be written to remote or local + directory Returns: - uploaded file as `File` object + `FileInfo` object representing written file Raises: - OSError: - if unable to upload file to remote directory or if file is not found. 
- + OSError: if unable to write file to directory """ - local_file = os.path.normpath(os.path.join(local_dir, file)) - try: - logger.debug(f"Uploading {local_file} to {remote_dir} via SFTP client") - uploaded_file = self.connection.put( - local_file, f"{remote_dir}/{file}", confirm=True - ) - return File.from_stat_data(uploaded_file, file_name=file) - except OSError: - logger.error( - f"Unable to upload {local_file} to {remote_dir}: {sys.exc_info()[1]}" - ) - raise + file.file_stream.seek(0) + if remote: + try: + self._check_dir(dir) + written_file = self.connection.putfo( + file.file_stream, + remotepath=file.file_name, + ) + return FileInfo.from_stat_data(written_file, file_name=file.file_name) + except OSError as e: + logger.error( + f"Unable to write {file.file_name} to remote directory: {e}" + ) + raise RetrieverFileError + else: + try: + local_file = f"{dir}/{file}" + with open(local_file, "wb") as lf: + lf.write(file.file_stream.getbuffer()) + return FileInfo.from_stat_data(os.stat(local_file), file.file_name) + except OSError as e: + logger.error( + f"Unable to write {file.file_name} to local directory: {e}" + ) + raise RetrieverFileError diff --git a/file_retriever/connect.py b/file_retriever/connect.py index c1b31e3..777855e 100644 --- a/file_retriever/connect.py +++ b/file_retriever/connect.py @@ -1,6 +1,6 @@ -"""Public class for interacting with remote storage. Can be used for to create -ftp or sftp client. - +""" +This module contains the `Client` class which can be used to create an ftp or +sftp client to interact with remote storage. 
""" import datetime @@ -8,7 +8,8 @@ import os from typing import List, Optional, Union from file_retriever._clients import _ftpClient, _sftpClient -from file_retriever.file import File +from file_retriever.file import FileInfo, File +from file_retriever.errors import RetrieverFileError logger = logging.getLogger("file_retriever") @@ -22,7 +23,7 @@ class Client: def __init__( self, - vendor: str, + name: str, username: str, password: str, host: str, @@ -32,26 +33,37 @@ def __init__( """Initializes client instance. Args: - vendor: name of vendor - username: username for server - password: password for server - host: server address - port: port number for server - remote_dir: directory on server to interact with + name: + name of server or vendor (eg. 'leila', 'nsdrop'). primarily + used in logging to track client activity. + username: + username for server + password: + password for server + host: + server address + port: + port number for server. 21 for FTP, 22 for SFTP + remote_dir: + directory on server to interact with. for most vendor servers + there is a default directory to interact with (eg. 'files' or + 'invoices'). this directory will be used in methods that take + `remote_dir` as an arg if another value is not provided. 
""" - - self.vendor = vendor + self.name = name self.host = host self.port = port self.remote_dir = remote_dir self.session = self.__connect_to_server(username=username, password=password) + logger.info(f"({self.name}) Connected to server") def __connect_to_server( self, username: str, password: str ) -> Union[_ftpClient, _sftpClient]: match self.port: case 21 | "21": + logger.info(f"({self.name}) Connecting to {self.host} via FTP client") return _ftpClient( username=username, password=password, @@ -59,6 +71,7 @@ def __connect_to_server( port=self.port, ) case 22 | "22": + logger.info(f"({self.name}) Connecting to {self.host} via SFTP client") return _sftpClient( username=username, password=password, @@ -83,154 +96,180 @@ def __exit__(self, *args): Closes context manager. """ - logger.debug("Closing client session") + self.close() + + def close(self): + """Closes connection""" + logger.info(f"({self.name}) Closing client session") self.session.close() - logger.debug("Connection closed") + logger.info(f"({self.name}) Connection closed") def check_connection(self) -> bool: """Checks if connection to server is active.""" return self.session.is_active() - def check_file(self, file: str, check_dir: str, remote: bool) -> bool: + def file_exists(self, file: FileInfo, dir: str, remote: bool) -> bool: """ - Check if `file` exists in `check_dir`. If `remote` is True then check will - be performed on server, otherwise check will be performed locally. + Check if file (represented as `FileInfo` object) exists in `dir`. + If `remote` is the directory will be checked on the server connected + to via self.session, otherwise the local directory will be checked for + the file. Returns True if file with same name and size as `file` exists + in `dir`, otherwise False. 
Args: - file: name of file to check - check_dir: directory to check for file - remote: whether to check file on server (True) or locally (False) + file: file to check for as `FileInfo` object + dir: directory to check for file + remote: whether to check for file on server (True) or locally (False) Returns: - bool indicating if file exists in `check_dir` + bool indicating if `file` exists in `dir` """ if remote: - remote_file = self.session.get_remote_file_data(file, check_dir) - return remote_file.file_name == file + try: + check_file = self.session.get_file_data( + file_name=file.file_name, dir=dir + ) + return ( + check_file.file_name == file.file_name + and check_file.file_size == file.file_size + ) + except RetrieverFileError: + return False else: - return os.path.exists(os.path.join(check_dir, file)) + return os.path.exists(f"{dir}/{file.file_name}") - def get_file( - self, - file: str, - remote_dir: Optional[str] = None, - local_dir: str = ".", - check: bool = True, - ) -> File: + def get_file(self, file: FileInfo, remote_dir: Optional[str] = None) -> File: """ - Downloads `file` from `remote_dir` on server to `local_dir`. If `remote_dir` - is not provided then file will be downloaded from `self.remote_dir`. If - `local_dir` is not provided then file will be downloaded to cwd. If `check` is - True, then `local_dir` will be checked for file before downloading. + Fetches a file from a server. Args: - file: name of file to download - remote_dir: directory on server to download file from - local_dir: local directory to download file to - check: check if file exists in `local_dir` before downloading + file: file represented as `FileInfo` object + remote_dir: directory on server to fetch file from Returns: - file downloaded to `local_dir` as `File` object + file fetched from `remote_dir` as `File` object """ if not remote_dir or remote_dir is None: - logger.debug(f"Param `remote_dir` not passed. 
Using {self.remote_dir}.") remote_dir = self.remote_dir - if check and self.check_file(file, check_dir=local_dir, remote=False): - logger.error( - f"{file} not downloaded to {local_dir} because it already exists." - ) - raise FileExistsError - self.session.download_file( - file=file, remote_dir=remote_dir, local_dir=local_dir - ) - logger.debug(f"{file} downloaded to {local_dir} directory") - local_file = os.path.normpath(os.path.join(local_dir, file)) - return File.from_stat_data(os.stat(local_file), file) + logger.debug(f"({self.name}) Fetching {file.file_name} from " f"`{remote_dir}`") + return self.session.fetch_file(file=file, dir=remote_dir) - def get_file_data(self, file: str, remote_dir: Optional[str] = None) -> File: + def get_file_info( + self, file_name: str, remote_dir: Optional[str] = None + ) -> FileInfo: """ - Retrieve metadata for `file` in `remote_dir` on server. If `remote_dir` is not - provided then data for `file` in `self.remote_dir` will be retrieved. + Retrieves metadata for a file on server. Args: - file: name of file to retrieve metadata for + file_name: name of file to retrieve metadata for remote_dir: directory on server to interact with Returns: - file in `remote_dir` represented as `File` object + file in `remote_dir` represented as `FileInfo` object """ if not remote_dir or remote_dir is None: - logger.debug(f"Param `remote_dir` not passed. 
Using {self.remote_dir}.") remote_dir = self.remote_dir - return self.session.get_remote_file_data(file, remote_dir) + logger.debug( + f"({self.name}) Retrieving file info for {file_name} " f"from {remote_dir}" + ) + try: + return self.session.get_file_data(file_name=file_name, dir=remote_dir) + except RetrieverFileError as e: + logger.error( + f"({self.name}) Unable to retrieve file data for {file_name}: " f"{e}" + ) + raise e - def list_files_in_dir( - self, time_delta: int = 0, remote_dir: Optional[str] = None - ) -> List[File]: + def list_file_info( + self, + time_delta: Union[datetime.timedelta, int] = 0, + remote_dir: Optional[str] = None, + ) -> List[FileInfo]: """ - Lists each file in `remote_dir` directory on server. If `remote_dir` is not - provided then files in `self.remote_dir` will be listed. If `time_delta` - is provided then files created in the last x days will be listed where x - is the `time_delta`. + Lists each file in a directory on server. If `time_delta` + is provided then files created in period since today - time_delta + will be listed. time_delta can be an integer representing the number of + days or a `datetime.timedelta` object. Args: - time_delta: number of days to go back in time to list files - remote_dir: directory on server to interact with + time_delta: + how far back to check for files. can be an integer representing + the number of days or a `datetime.timedelta` object. default is 0, + ie. all files will be listed. + remote_dir: + directory on server to interact with Returns: - list of files in `remote_dir` represented as `File` objects + list of files in `remote_dir` represented as `FileInfo` objects """ - today = datetime.datetime.now() - + today = datetime.datetime.now(tz=datetime.timezone.utc) + if isinstance(time_delta, int): + time_delta = datetime.timedelta(days=time_delta) + else: + time_delta = time_delta if not remote_dir or remote_dir is None: - logger.debug(f"Param `remote_dir` not passed. 
Using {self.remote_dir}.") remote_dir = self.remote_dir - files = self.session.list_remote_file_data(remote_dir) - if time_delta > 0: - logger.debug(f"Checking for files modified in last {time_delta} days.") - return [ + logger.debug(f"({self.name}) Retrieving list of files in `{remote_dir}`") + files = self.session.list_file_data(dir=remote_dir) + if time_delta > datetime.timedelta(days=0): + logger.debug( + f"({self.name}) Filtering list for files created " + f"since {datetime.datetime.strftime((today - time_delta), '%Y-%m-%d')}" + ) + recent_files = [ i for i in files if datetime.datetime.fromtimestamp( i.file_mtime, tz=datetime.timezone.utc ) - >= today - datetime.timedelta(days=time_delta) + >= today - time_delta ] + logger.debug( + f"({self.name}) {len(recent_files)} recent files in `{remote_dir}`" + ) + return recent_files else: + logger.debug(f"({self.name}) {len(files)} in `{remote_dir}`") return files def put_file( self, - file: str, - local_dir: str = ".", - remote_dir: Optional[str] = None, - check: bool = True, - ) -> File: + file: File, + dir: str, + remote: bool, + check: bool, + ) -> Optional[FileInfo]: """ - Uploads file from local directory to server. If `remote_dir` is not - provided then file will be uploaded to `self.remote_dir`. If `local_dir` - is not provided then file will be uploaded from cwd. If `check` is - True, then `remote_dir` will be checked for file before downloading. + Writes file to directory. Args: - file: name of file to upload - local_dir: local directory to upload file from - remote_dir: remote directory to upload file to - check: check if file exists in `remote_dir` before uploading + file: + file as `File` object + dir: + directory to write file to + remote: + bool indicating if file should be written to remote or local storage. + + If True, then file is written to `dir` on server. + If False, then file is written to local `dir` directory. + check: + bool indicating if directory should be checked before writing file. 
+ + If True, then `dir` will be checked for files matching the file_name + and file_size of `file` before writing to `dir`. If a match is found + then `file` will not be written. Returns: - file uploaded to `remote_dir` as `File` object + `FileInfo` object representing written file, or None if the write was skipped """ - if remote_dir is None: - logger.debug(f"Param `remote_dir` not passed. Using {self.remote_dir}.") - remote_dir = self.remote_dir - if check and self.check_file(file, check_dir=remote_dir, remote=True): - logger.error( - f"{file} not uploaded to {remote_dir} because it already exists" + if check: + logger.debug(f"({self.name}) Checking for file in `{dir}` before writing") + if check and self.file_exists(file=file, dir=dir, remote=remote) is True: + logger.debug( + f"({self.name}) Skipping {file.file_name}. File already " + f"exists in `{dir}`." ) - raise FileExistsError - uploaded_file = self.session.upload_file( - file=file, remote_dir=remote_dir, local_dir=local_dir - ) - logger.debug(f"{file} uploaded from {local_dir} to {remote_dir} directory") - return uploaded_file + return None + else: + logger.debug(f"({self.name}) Writing {file.file_name} to `{dir}`") + return self.session.write_file(file=file, dir=dir, remote=remote) diff --git a/file_retriever/errors.py b/file_retriever/errors.py new file mode 100644 index 0000000..f9ca22a --- /dev/null +++ b/file_retriever/errors.py @@ -0,0 +1,25 @@ +"""This module contains custom exceptions for the file_retriever package.""" + + +class FileRetrieverError(Exception): + """Base class for exceptions in the file_retriever package.""" + + pass + + +class RetrieverAuthenticationError(FileRetrieverError): + """Exception raised for errors in authenticating to a server.""" + + pass + + +class RetrieverConnectionError(FileRetrieverError): + """Exception raised for errors in connecting to the file server.""" + + pass + + +class RetrieverFileError(FileRetrieverError): + """Exception raised for errors in finding or accessing a requested file.""" + 
+ pass diff --git a/file_retriever/file.py b/file_retriever/file.py index 03563c3..70172b5 100644 --- a/file_retriever/file.py +++ b/file_retriever/file.py @@ -1,99 +1,175 @@ -from dataclasses import dataclass +"""This module contains classes to store file metadata and content.""" + import datetime +import io import os import paramiko from typing import Optional, Union -@dataclass -class File: - """A dataclass to store file information.""" +class FileInfo: + """A class to store file metadata.""" + + def __init__( + self, + file_name: str, + file_mtime: Union[float, int, str], + file_mode: Union[str, int], + file_size: int, + file_uid: Optional[int] = None, + file_gid: Optional[int] = None, + file_atime: Optional[float] = None, + ): + """Initialize `FileInfo` object with file metadata. + + File metadata includes attributes included in `os.stat_result` and + `paramiko.SFTPAttributes` objects. The `file_mtime` attribute can be a + float, int or string. If it is a string, it is parsed to a timestamp + as an int. The `file_mode` attribute can be a string or int. If it is a + string, it is parsed to a decimal value. `file_uid`, `file_gid` and + `file_atime` are optional attribute as they are not always available, + especially for files accessed via FTP. 
+ + Args: + + file_name: name of file + file_mtime: file modification time + file_mode: file permissions + file_size: file size + file_uid: file owner user id + file_gid: file owner group id + file_atime: file access time + """ + self.file_name = file_name + self.file_size = file_size + self.file_uid = file_uid + self.file_gid = file_gid + self.file_atime = file_atime + if isinstance(file_mtime, str): + self.file_mtime = self.__parse_mdtm_time(file_mtime) + else: + self.file_mtime = int(file_mtime) - file_name: str - file_mtime: float - file_size: Optional[int] = None - file_uid: Optional[int] = None - file_gid: Optional[int] = None - file_atime: Optional[float] = None - file_mode: Optional[int] = None + if isinstance(file_mode, str): + self.file_mode = self.__parse_permissions(file_mode) + else: + self.file_mode = int(file_mode) @classmethod def from_stat_data( cls, data: Union[os.stat_result, paramiko.SFTPAttributes], file_name: Optional[str] = None, - ) -> "File": + ) -> "FileInfo": """ - Creates a `File` object from `os.stat_result` or `paramiko.SFTPAttributes` data. - Accepts data returned by `paramiko.SFTPClient.stat`, `paramiko.SFTPClient.put`, - `paramiko.SFTPClient.listdir_attr` or `os.stat` methods. + Creates a `FileInfo` object from `os.stat_result` or `paramiko.SFTPAttributes` + data. Accepts data returned by `paramiko.SFTPClient.stat`, + `paramiko.SFTPClient.put`, `paramiko.SFTPClient.listdir_attr` or `os.stat` + methods. 
Args: - stat_result_data: data formatted like os.stat_result + data: data formatted like `os.stat_result` object file_name: name of file Returns: - `File` object + `FileInfo` object """ - if file_name is not None: - filename = file_name - elif ( - isinstance(data, paramiko.SFTPAttributes) - and hasattr(data, "filename") is True - ): - filename = data.filename - elif ( - isinstance(data, paramiko.SFTPAttributes) - and hasattr(data, "longname") is True - ): - filename = data.longname[56:] + match data, file_name: + case data, file_name if file_name is not None: + file_name = file_name + case data, None if isinstance(data, paramiko.SFTPAttributes) and hasattr( + data, "filename" + ) and data.filename is not None: + file_name = data.filename + case data, None if isinstance(data, paramiko.SFTPAttributes) and hasattr( + data, "longname" + ) and data.longname is not None: + file_name = data.longname[56:] + case _: + raise AttributeError("No filename provided") + + match data.st_mode: + case data.st_mode if isinstance(data.st_mode, int): + st_mode: Union[str, int] = data.st_mode + case data.st_mode if isinstance( + data, paramiko.SFTPAttributes + ) and data.st_mode is None and hasattr( + data, "longname" + ) and data.longname is not None: + st_mode = data.longname[0:10] + case _: + raise AttributeError("No file mode provided") + + if hasattr(data, "st_size") and data.st_size is not None: + st_size = data.st_size else: - raise AttributeError("No filename provided") + raise AttributeError("No file size provided") - if not hasattr(data, "st_mtime") or data.st_mtime is None: + if ( + hasattr(data, "st_mtime") + and data.st_mtime is not None + and isinstance(data.st_mtime, float | int) + ): + st_mtime = data.st_mtime + else: raise AttributeError("No file modification time provided") return cls( - filename, - data.st_mtime, - data.st_size, - data.st_uid, - data.st_gid, - data.st_atime, - data.st_mode, + file_name=file_name, + file_mtime=st_mtime, + file_mode=st_mode, + 
file_size=st_size, + file_uid=data.st_uid, + file_gid=data.st_gid, + file_atime=data.st_atime, ) - @staticmethod - def parse_mdtm_time(mdtm_time: str) -> int: - """parse string returned by MDTM command to timestamp as int.""" + def __parse_mdtm_time(self, mdtm_time: str) -> int: + """parse date formatted as string (YYYYMMDDHHMMSS) to int timestamp.""" return int( - datetime.datetime.strptime(mdtm_time[4:], "%Y%m%d%H%M%S") + datetime.datetime.strptime(mdtm_time, "%Y%m%d%H%M%S") .replace(tzinfo=datetime.timezone.utc) .timestamp() ) - @staticmethod - def parse_permissions(permissions_str: str) -> int: + def __parse_permissions(self, file_mode: str) -> int: """ - parse permissions string to decimal value. - - permissions: - a 10 character string representing the permissions associated with - a file. The first character represents the file type, the next 9 - characters represent the permissions. - eg. '-rw-rw-rw-' - this string is parsed to extract the file mode in decimal notation - using the following formula: - digit 1 (filetype), digits 2-4 (owner permissions), digits 5-7 - (group permissions), and digits 8-10 (other permissions) are - converted to octal value (eg: '-rwxrwxrwx' -> 100777) the octal - number is then converted to a decimal value: - (filetype * 8^5) + (0 * 8^4) + (0 * 8^3) + (owner * 8^2) + - (group * 8^1) + (others * 8^0) = decimal value + parse permissions written as string in symbolic notation + (eg. -rwxrwxrwx) to decimal value. + + file_mode: + a 10 character string representing the permissions associated with + a file. The first character represents the file type, the next 9 + characters represent the owner, group, and public permissions. + eg. 
'-rwxrw-r--' + this string is parsed calculate the file's permission mode in + decimal notation using the following formula: + digit 1 (filetype): + the first character is converted to an octal value based + on the type: + d -> 4 (directory) + - -> 1 (file) + digits 2-10 (permissions by group): + within each group (2-4: owner, 5-7: group, and 8-10: public), + each digit is converted to an octal value: + r -> 4 (read) + w -> 2 (write) + x -> 1 (execute) + - -> 0 (no permission) + the octal values for each group of 3 characters are then added + 'rwxrwxrwx' -> '(4+2+1) (4+2+1) (4+2+1)' -> 777 + the octal values of the file type and permissions are then converted + to a decimal value using the following formula: + (filetype * 8^5) + (0 * 8^4) + (0 * 8^3) + (owner * 8^2) + + (group * 8^1) + (others * 8^0) = decimal value + example: + '-rwxrw-r--' -> 100764 + (1 * 8^5) + (0 * 8^4) + (0 * 8^3) + + (7 * 8^2) + (6 * 8^1) + (4 * 8^0) = 33264 """ - file_type = permissions_str[0].replace("d", "4").replace("-", "1") + file_type = file_mode[0].replace("d", "4").replace("-", "1") file_perm = ( - permissions_str[1:10] + file_mode[1:10] .replace("-", "0") .replace("r", "4") .replace("w", "2") @@ -107,3 +183,69 @@ def parse_permissions(permissions_str: str) -> int: + (int(int(file_perm[3]) + int(file_perm[4]) + int(file_perm[5])) * 8**1) + (int(int(file_perm[6]) + int(file_perm[7]) + int(file_perm[8])) * 8**0) ) + + +class File(FileInfo): + """A class to store file metadata and data stream.""" + + def __init__( + self, + file_name: str, + file_mtime: Union[float, str], + file_mode: Union[str, int], + file_size: int, + file_stream: io.BytesIO, + file_uid: Optional[int] = None, + file_gid: Optional[int] = None, + file_atime: Optional[float] = None, + ): + """Initialize `File` object with file metadata and data stream. + + File metadata includes attributes inherited from `FileInfo` class. + The `file_stream` attribute is a `io.BytesIO` object containing the + content of the file. 
+ + Args: + + file_name: name of file + file_mtime: file modification time + file_mode: file permissions + file_size: file size + file_stream: file stream as `io.BytesIO` + file_uid: file owner user id + file_gid: file owner group id + file_atime: file access time + """ + super().__init__( + file_name=file_name, + file_mtime=file_mtime, + file_mode=file_mode, + file_size=file_size, + file_uid=file_uid, + file_gid=file_gid, + file_atime=file_atime, + ) + self.file_stream = file_stream + + @classmethod + def from_fileinfo(cls, file: FileInfo, file_stream: io.BytesIO) -> "File": + """ + Creates a `File` object from a `FileInfo` object and a file stream. + + Args: + file: `FileInfo` object + file_stream: file stream as `io.BytesIO` + + Returns: + `File` object + """ + return cls( + file_name=file.file_name, + file_mtime=file.file_mtime, + file_mode=file.file_mode, + file_size=file.file_size, + file_stream=file_stream, + file_uid=file.file_uid, + file_gid=file.file_gid, + file_atime=file.file_atime, + ) diff --git a/file_retriever/utils.py b/file_retriever/utils.py index 7f7f550..3f96848 100644 --- a/file_retriever/utils.py +++ b/file_retriever/utils.py @@ -1,9 +1,25 @@ +"""This module contains helper functions for the file_retriever package.""" + import os +from typing import List import yaml def logger_config() -> dict: - """Create dict for logger configuration""" + """ + Create and return dict for logger configuration. + + INFO and DEBUG logs are recorded in methods of the `Client` class while + ERROR logs are primarily recorded in methods of the `_ftpClient` and + `_sftpClient` classes. The one exception to this is ERROR messages + logged by the `_ftpClient` and `_sftpClient` `get_file_data` methods. + These are logged as errors in the `Client` class in order avoid logging + errors when files are not found by the `Client.file_exists` method. 
+ + Returns: + dict: dictionary with logger configuration + + """ log_config_dict = { "version": 1, "formatters": { @@ -30,9 +46,24 @@ def logger_config() -> dict: return log_config_dict -def vendor_config(config_path: str) -> None: - """Set environment variables from config file""" +def client_config(config_path: str) -> List[str]: + """ + Read config file with credentials and set creds as environment variables. + Returns a list of names for servers whose credentials are stored in the + config file and have been added to env vars. + + Args: + config_path (str): Path to the yaml file with credendtials. + + Returns: + list of names of servers (eg. EASTVIEW, NSDROP) whose credentials are + stored in the config file and have been added to env vars + """ with open(config_path, "r") as file: config = yaml.safe_load(file) for k, v in config.items(): os.environ[k] = v + vendor_list = [ + i.split("_HOST")[0] for i in config.keys() if i.endswith("_HOST") + ] + return vendor_list diff --git a/tests/conftest.py b/tests/conftest.py index 2ca2aab..965dca5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,22 @@ import datetime import ftplib +import logging import os import paramiko -from typing import Dict, List +from typing import Dict, List, Optional import yaml import pytest from file_retriever._clients import _ftpClient, _sftpClient, _BaseClient from file_retriever.connect import Client +from file_retriever.file import FileInfo + +logger = logging.getLogger("file_retriever") class FakeUtcNow(datetime.datetime): @classmethod - def now(cls, tzinfo=datetime.timezone.utc): - return cls(2024, 6, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc) + def now(cls, tz=datetime.timezone.utc): + return cls(2024, 6, 1, 1, 0, 0, 0, datetime.timezone.utc) class MockChannel: @@ -39,19 +43,47 @@ def __init__(self): self.st_uid = 0 self.st_size = 140401 - def create_SFTPAttributes(self): + def sftp_attr(self): sftp_attr = paramiko.SFTPAttributes() - sftp_attr.__dict__ = 
self.__dict__ sftp_attr.filename = self.file_name + sftp_attr.st_mtime = self.st_mtime + sftp_attr.st_mode = self.st_mode + sftp_attr.st_atime = self.st_atime + sftp_attr.st_gid = self.st_gid + sftp_attr.st_uid = self.st_uid + sftp_attr.st_size = self.st_size return sftp_attr + def file_info(self): + return FileInfo( + file_name=self.file_name, + file_mtime=self.st_mtime, + file_size=self.st_size, + file_uid=self.st_uid, + file_gid=self.st_gid, + file_atime=self.st_atime, + file_mode=self.st_mode, + ) + + def os_stat_result(self): + result = os.stat_result() + result.st_mtime = self.st_mtime + result.st_mode = self.st_mode + result.st_atime = self.st_atime + result.st_gid = self.st_gid + result.st_uid = self.st_uid + result.st_size = self.st_size + return result + @pytest.fixture -def mock_file_data(monkeypatch): - def mock_stat(*args, **kwargs): - return MockFileData() +def mock_sftp_attr(): + return MockFileData().sftp_attr() - monkeypatch.setattr(os, "stat", mock_stat) + +@pytest.fixture +def mock_file_info(): + return MockFileData().file_info() @pytest.fixture @@ -67,18 +99,25 @@ class MockFTP: def close(self, *args, **kwargs) -> None: pass + def cwd(self, pathname) -> str: + return pathname + def nlst(self, *args, **kwargs) -> List[str]: - return ["foo.mrc"] + return [MockFileData().file_name] - def retrbinary(self, *args, **kwargs) -> None: - pass + def pwd(self, *args, **kwargs) -> str: + return "/" + + def retrbinary(self, *args, **kwargs) -> bytes: + file = b"00000" + return args[1](file) def retrlines(self, *args, **kwargs) -> str: files = "-rw-r--r-- 1 0 0 140401 Jan 1 00:01 foo.mrc" return args[1](files) def size(self, *args, **kwargs) -> int: - return 140401 + return MockFileData().st_size def storbinary(self, *args, **kwargs) -> None: pass @@ -93,23 +132,29 @@ def voidcmd(self, *args, **kwargs) -> str: class MockSFTPClient: """Mock response from SFTP for a successful login""" - def close(self, *args, **kwargs) -> None: + def chdir(self, *args, 
**kwargs) -> None: pass - def get(self, *args, **kwargs) -> None: - open(args[1], "x+") + def close(self, *args, **kwargs) -> None: + pass def get_channel(self, *args, **kwargs) -> MockChannel: return MockChannel() + def getcwd(self) -> Optional[str]: + return None + + def getfo(self, remotepath, fl, *args, **kwargs) -> bytes: + return fl.write(b"00000") + def listdir_attr(self, *args, **kwargs) -> List[paramiko.SFTPAttributes]: - return [MockFileData().create_SFTPAttributes()] + return [MockFileData().sftp_attr()] - def put(self, *args, **kwargs) -> paramiko.SFTPAttributes: - return MockFileData().create_SFTPAttributes() + def putfo(self, *args, **kwargs) -> paramiko.SFTPAttributes: + return MockFileData().sftp_attr() def stat(self, *args, **kwargs) -> paramiko.SFTPAttributes: - return MockFileData().create_SFTPAttributes() + return MockFileData().sftp_attr() class MockABCClient: @@ -147,17 +192,39 @@ def mock_ftp_client(*args, **kwargs): def mock_sftp_client(*args, **kwargs): return MockSFTPClient() + def mock_stat(*args, **kwargs): + return MockFileData() + + monkeypatch.setattr(os, "stat", mock_stat) monkeypatch.setattr(_ftpClient, "_connect_to_server", mock_ftp_client) monkeypatch.setattr(_sftpClient, "_connect_to_server", mock_sftp_client) @pytest.fixture -def mock_Client(monkeypatch, mock_ftpClient_sftpClient, mock_file_data): +def mock_cwd(monkeypatch, mock_ftpClient_sftpClient): + def mock_root(*args, **kwargs): + return "/" + + monkeypatch.setattr(MockSFTPClient, "getcwd", mock_root) + monkeypatch.setattr(MockSFTPClient, "getcwd", mock_root) + + +@pytest.fixture +def mock_other_dir(monkeypatch, mock_ftpClient_sftpClient): + def mock_dir(*args, **kwargs): + return "bar" + + monkeypatch.setattr(MockSFTPClient, "getcwd", mock_dir) + monkeypatch.setattr(MockSFTPClient, "getcwd", mock_dir) + + +@pytest.fixture +def mock_Client(monkeypatch, mock_ftpClient_sftpClient): def mock_file_exists(*args, **kwargs): return False monkeypatch.setattr(os.path, "exists", 
mock_file_exists) - monkeypatch.setattr(Client, "check_file", mock_file_exists) + monkeypatch.setattr(Client, "file_exists", mock_file_exists) @pytest.fixture @@ -192,14 +259,6 @@ def mock_ssh_error(*args, **kwargs): monkeypatch.setattr(ftplib.FTP, "login", mock_ftp_error_temp) -@pytest.fixture -def mock_permissions_error(monkeypatch, mock_open_file, stub_client): - def mock_retrlines(*args, **kwargs): - return None - - monkeypatch.setattr(ftplib.FTP, "retrlines", mock_retrlines) - - @pytest.fixture def mock_file_error(monkeypatch, mock_open_file, mock_ftpClient_sftpClient): def mock_os_error(*args, **kwargs): @@ -208,18 +267,31 @@ def mock_os_error(*args, **kwargs): def mock_ftp_error_perm(*args, **kwargs): raise ftplib.error_perm - def mock_retrlines(*args, **kwargs): + def mock_none_return(*args, **kwargs): return None - monkeypatch.setattr(MockFTP, "voidcmd", mock_ftp_error_perm) - monkeypatch.setattr(MockFTP, "size", mock_ftp_error_perm) - monkeypatch.setattr(MockFTP, "retrlines", mock_retrlines) - monkeypatch.setattr(MockFTP, "retrbinary", mock_os_error) - monkeypatch.setattr(MockFTP, "storbinary", mock_os_error) monkeypatch.setattr(MockSFTPClient, "stat", mock_os_error) - monkeypatch.setattr(MockSFTPClient, "get", mock_os_error) - monkeypatch.setattr(MockSFTPClient, "put", mock_os_error) + monkeypatch.setattr(MockSFTPClient, "getfo", mock_os_error) + monkeypatch.setattr(MockSFTPClient, "putfo", mock_os_error) monkeypatch.setattr(MockSFTPClient, "listdir_attr", mock_os_error) + monkeypatch.setattr(os, "stat", mock_os_error) + monkeypatch.setattr(MockFTP, "voidcmd", mock_ftp_error_perm) + monkeypatch.setattr(MockFTP, "nlst", mock_ftp_error_perm) + monkeypatch.setattr(MockFTP, "retrbinary", mock_ftp_error_perm) + monkeypatch.setattr(MockFTP, "storbinary", mock_ftp_error_perm) + monkeypatch.setattr(MockFTP, "size", mock_ftp_error_perm) + monkeypatch.setattr(MockFTP, "retrlines", mock_none_return) + + +@pytest.fixture +def mock_ftp_file_not_found(monkeypatch, 
mock_open_file, mock_ftpClient_sftpClient): + def mock_none_return(*args, **kwargs): + return None + + monkeypatch.setattr(MockFTP, "voidcmd", mock_none_return) + monkeypatch.setattr(MockFTP, "nlst", mock_none_return) + monkeypatch.setattr(MockFTP, "size", mock_none_return) + monkeypatch.setattr(MockFTP, "retrlines", mock_none_return) @pytest.fixture @@ -265,7 +337,7 @@ def live_ftp_creds() -> Dict[str, str]: "password": data["LEILA_PASSWORD"], "host": data["LEILA_HOST"], "port": data["LEILA_PORT"], - "vendor": "leila", + "name": "leila", "remote_dir": data["LEILA_SRC"], } @@ -281,7 +353,7 @@ def live_sftp_creds() -> Dict[str, str]: "password": data["EASTVIEW_PASSWORD"], "host": data["EASTVIEW_HOST"], "port": data["EASTVIEW_PORT"], - "vendor": "eastview", + "name": "eastview", "remote_dir": data["EASTVIEW_SRC"], } @@ -296,5 +368,7 @@ def NSDROP_creds() -> Dict[str, str]: "username": data["NSDROP_USER"], "password": data["NSDROP_PASSWORD"], "host": data["NSDROP_HOST"], - "port": "22", + "port": data["NSDROP_PORT"], + "name": "nsdrop", + "remote_dir": data["NSDROP_SRC"], } diff --git a/tests/test_clients.py b/tests/test_clients.py index 94ee3fe..dca2ee2 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -1,23 +1,34 @@ from contextlib import nullcontext as does_not_raise import datetime -import ftplib -import os -import paramiko +import io +import logging +import logging.config import pytest from file_retriever._clients import _ftpClient, _sftpClient, _BaseClient -from file_retriever.file import File +from file_retriever.file import FileInfo, File +from file_retriever.utils import logger_config +from file_retriever.errors import ( + RetrieverFileError, + RetrieverConnectionError, + RetrieverAuthenticationError, +) +logger = logging.getLogger("file_retriever") +config = logger_config() +logging.config.dictConfig(config) -def test_BaseClient(): + +def test_BaseClient(mock_file_info): _BaseClient.__abstractmethods__ = set() ftp_bc = 
_BaseClient(username="foo", password="bar", host="baz", port=21) assert ftp_bc.__dict__ == {"connection": None} + assert ftp_bc._check_dir(dir="foo") is None assert ftp_bc.close() is None - assert ftp_bc.download_file("foo.mrc", "bar", "baz") is None - assert ftp_bc.get_remote_file_data("foo.mrc", "bar") is None + assert ftp_bc.fetch_file(file="foo.mrc", dir="bar") is None + assert ftp_bc.get_file_data(file_name="foo.mrc", dir="bar") is None + assert ftp_bc.list_file_data(dir="foo") is None assert ftp_bc.is_active() is None - assert ftp_bc.list_remote_file_data("foo") is None - assert ftp_bc.upload_file("foo.mrc", "bar", "baz") is None + assert ftp_bc.write_file(file=mock_file_info, dir="bar", remote=True) is None class TestMock_ftpClient: @@ -35,137 +46,147 @@ def test_ftpClient_no_creds(self, stub_client): def test_ftpClient_error_perm(self, mock_auth_error, stub_creds): stub_creds["port"] = "21" - with pytest.raises(ftplib.error_perm): + with pytest.raises(RetrieverAuthenticationError): _ftpClient(**stub_creds) def test_ftpClient_error_temp(self, mock_login_connection_error, stub_creds): stub_creds["port"] = "21" - with pytest.raises(ftplib.error_temp): + with pytest.raises(RetrieverConnectionError): _ftpClient(**stub_creds) - def test_ftpClient_close(self, mock_ftpClient_sftpClient, stub_creds): + def test_ftpClient_check_dir(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - connection = ftp.close() - assert connection is None + with does_not_raise(): + ftp._check_dir(dir="foo") - def test_ftpClient_download_file(self, mock_ftpClient_sftpClient, stub_creds): + def test_ftpClient_check_dir_cwd(self, mock_cwd, stub_creds): stub_creds["port"] = "21" + ftp = _ftpClient(**stub_creds) with does_not_raise(): - ftp = _ftpClient(**stub_creds) - ftp.download_file(file="foo.mrc", remote_dir="bar", local_dir="test") + ftp._check_dir(dir="/") - def test_ftpClient_download_file_not_found(self, mock_file_error, 
stub_creds): + def test_ftpClient_close(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "21" - with pytest.raises(OSError): - ftp = _ftpClient(**stub_creds) - ftp.download_file(file="foo.mrc", remote_dir="bar", local_dir="test") + ftp = _ftpClient(**stub_creds) + connection = ftp.close() + assert connection is None - def test_ftpClient_download_connection_error( - self, mock_connection_error_reply, stub_creds + def test_ftpClient_fetch_file( + self, mock_ftpClient_sftpClient, mock_file_info, stub_creds ): stub_creds["port"] = "21" - with pytest.raises(ftplib.error_reply): - ftp = _ftpClient(**stub_creds) - ftp.download_file(file="foo.mrc", remote_dir="bar", local_dir="test") + ftp = _ftpClient(**stub_creds) + fh = ftp.fetch_file(file=mock_file_info, dir="bar") + assert fh.file_stream.getvalue()[0:1] == b"0" - def test_ftpClient_get_remote_file_data( - self, mock_ftpClient_sftpClient, stub_creds + def test_ftpClient_fetch_file_permissions_error( + self, mock_file_info, mock_file_error, stub_creds ): + stub_creds["port"] = "21" + with pytest.raises(RetrieverFileError): + ftp = _ftpClient(**stub_creds) + ftp.fetch_file(file=mock_file_info, dir="bar") + + def test_ftpClient_get_file_data(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - files = ftp.get_remote_file_data("foo.mrc", "testdir") - assert files == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=None, - file_gid=None, - file_atime=None, - ) + file_data = ftp.get_file_data(file_name="foo.mrc", dir="testdir") + assert file_data.file_name == "foo.mrc" + assert file_data.file_mtime == 1704070800 + assert file_data.file_size == 140401 + assert file_data.file_mode == 33188 + assert file_data.file_uid is None + assert file_data.file_gid is None + assert file_data.file_atime is None - def test_ftpClient_get_remote_file_data_connection_error( - self, mock_connection_error_reply, stub_creds 
- ): + def test_ftpClient_get_file_data_error_perm(self, mock_file_error, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - with pytest.raises(ftplib.error_reply): - ftp.get_remote_file_data("foo.mrc", "testdir") + with pytest.raises(RetrieverFileError): + ftp.get_file_data(file_name="foo.mrc", dir="testdir") - def test_ftpClient_get_remote_file_data_not_found( - self, mock_file_error, stub_creds + def test_ftpClient_get_file_data_file_not_found( + self, mock_ftp_file_not_found, stub_creds ): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - with pytest.raises(ftplib.error_perm): - ftp.get_remote_file_data("foo.mrc", "testdir") + with pytest.raises(RetrieverFileError): + ftp.get_file_data(file_name="foo.mrc", dir="testdir") - def test_ftpClient_get_remote_file_data_non_permissions( - self, mock_permissions_error, stub_creds - ): + def test_ftpClient_list_file_data(self, mock_ftpClient_sftpClient, stub_creds): + stub_creds["port"] = "21" + ftp = _ftpClient(**stub_creds) + files = ftp.list_file_data(dir="testdir") + assert all(isinstance(file, FileInfo) for file in files) + assert len(files) == 1 + assert files[0].file_name == "foo.mrc" + assert files[0].file_mtime == 1704070800 + assert files[0].file_size == 140401 + assert files[0].file_mode == 33188 + assert files[0].file_uid is None + assert files[0].file_gid is None + assert files[0].file_atime is None + + def test_ftpClient_list_file_data_error_perm(self, mock_file_error, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - with pytest.raises(ftplib.error_perm): - ftp.get_remote_file_data("foo.mrc", "testdir") + with pytest.raises(RetrieverFileError): + ftp.list_file_data(dir="testdir") - def test_ftpClient_is_active(self, mock_ftpClient_sftpClient, stub_creds): + def test_ftpClient_is_active_true(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) live_connection = ftp.is_active() assert live_connection is True 
- def test_ftpClient_is_inactive(self, mock_connection_dropped, stub_creds): + def test_ftpClient_is_active_false(self, mock_connection_dropped, stub_creds): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) live_connection = ftp.is_active() assert live_connection is False - def test_ftpClient_list_remote_file_data( - self, mock_ftpClient_sftpClient, stub_creds + def test_ftpClient_write_file( + self, mock_ftpClient_sftpClient, mock_file_info, stub_creds ): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - files = ftp.list_remote_file_data("testdir") - assert files == [ - File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=None, - file_gid=None, - file_atime=None, - ) - ] - - def test_ftpClient_list_remote_file_data_connection_error( - self, mock_connection_error_reply, stub_creds + assert mock_file_info.file_name == "foo.mrc" + file_obj = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + assert file_obj.file_name == "foo.mrc" + remote_file = ftp.write_file(file=file_obj, dir="bar", remote=True) + local_file = ftp.write_file(file=file_obj, dir="bar", remote=False) + assert remote_file.file_mtime == 1704070800 + assert remote_file.file_size == 140401 + assert local_file.file_mtime == 1704070800 + assert local_file.file_size == 140401 + + def test_ftpClient_write_file_no_file_stream( + self, mock_file_error, mock_file_info, stub_creds ): stub_creds["port"] = "21" ftp = _ftpClient(**stub_creds) - with pytest.raises(ftplib.error_reply): - ftp.list_remote_file_data("testdir") + with pytest.raises(AttributeError) as exc: + ftp.write_file(file=mock_file_info, dir="bar", remote=False) + assert "'FileInfo' object has no attribute 'file_stream'" in str(exc.value) - def test_ftpClient_upload_file(self, mock_ftpClient_sftpClient, stub_creds): + def test_ftpClient_write_file_local_not_found( + self, mock_file_error, mock_file_info, stub_creds + ): stub_creds["port"] = "21" ftp = 
_ftpClient(**stub_creds) - file = ftp.upload_file(file="foo.mrc", local_dir="foo", remote_dir="bar") - assert file.file_mtime == 1704070800 - - def test_ftpClient_upload_file_not_found(self, mock_file_error, stub_creds): - stub_creds["port"] = "21" - with pytest.raises(OSError): - ftp = _ftpClient(**stub_creds) - ftp.upload_file(file="foo.mrc", remote_dir="foo", local_dir="bar") + file = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + ftp.write_file(file=file, dir="bar", remote=False) - def test_ftpClient_upload_connection_error( - self, mock_connection_error_reply, stub_creds + def test_ftpClient_write_file_remote_not_found( + self, mock_file_error, mock_file_info, stub_creds ): stub_creds["port"] = "21" - with pytest.raises(ftplib.error_reply): - ftp = _ftpClient(**stub_creds) - ftp.upload_file(file="foo.mrc", remote_dir="foo", local_dir="bar") + ftp = _ftpClient(**stub_creds) + file = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + ftp.write_file(file=file, dir="bar", remote=True) class TestMock_sftpClient: @@ -183,170 +204,210 @@ def test_sftpClient_no_creds(self, stub_client): def test_sftpClient_auth_error(self, mock_auth_error, stub_creds): stub_creds["port"] = "22" - with pytest.raises(paramiko.AuthenticationException): + with pytest.raises(RetrieverAuthenticationError): _sftpClient(**stub_creds) - def test_sftpclient_error_reply(self, mock_login_connection_error, stub_creds): + def test_sftpclient_SSHException(self, mock_login_connection_error, stub_creds): stub_creds["port"] = "22" - with pytest.raises(paramiko.SSHException): + with pytest.raises(RetrieverConnectionError): _sftpClient(**stub_creds) - def test_ftpClient_close(self, mock_ftpClient_sftpClient, stub_creds): + def test_sftpClient_check_dir(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - connection = 
sftp.close() - assert connection is None + with does_not_raise(): + sftp._check_dir(dir="foo") - def test_sftpClient_download_file(self, mock_ftpClient_sftpClient, stub_creds): + def test_sftpClient_check_dir_cwd(self, mock_cwd, stub_creds): stub_creds["port"] = "22" + sftp = _sftpClient(**stub_creds) with does_not_raise(): - sftp = _sftpClient(**stub_creds) - sftp.download_file(file="foo.mrc", remote_dir="bar", local_dir="test") + sftp._check_dir(dir="/") - def test_sftpClient_download_file_not_found(self, mock_file_error, stub_creds): + def test_sftpClient_check_dir_other_dir(self, mock_other_dir, stub_creds): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - with pytest.raises(OSError): - sftp.download_file(file="foo.mrc", remote_dir="bar", local_dir="test") + with does_not_raise(): + sftp._check_dir(dir="foo") + + def test_sftpClient_close(self, mock_ftpClient_sftpClient, stub_creds): + stub_creds["port"] = "22" + sftp = _sftpClient(**stub_creds) + connection = sftp.close() + assert connection is None - def test_sftpClient_get_remote_file_data( - self, mock_ftpClient_sftpClient, stub_creds + def test_sftpClient_fetch_file( + self, mock_ftpClient_sftpClient, mock_file_info, stub_creds ): stub_creds["port"] = "22" - ftp = _sftpClient(**stub_creds) - file = ftp.get_remote_file_data("foo.mrc", "testdir") - assert file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_uid=0, - file_gid=0, - file_atime=None, - file_mode=33188, - ) + sftp = _sftpClient(**stub_creds) + fh = sftp.fetch_file(file=mock_file_info, dir="bar") + assert fh.file_stream.getvalue()[0:1] == b"0" - def test_sftpClient_get_remote_file_data_not_found( - self, mock_file_error, stub_creds + def test_sftpClient_fetch_file_not_found( + self, mock_file_info, mock_file_error, stub_creds ): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - with pytest.raises(OSError): - sftp.get_remote_file_data("foo.mrc", "testdir") + with 
pytest.raises(RetrieverFileError): + sftp.fetch_file(file=mock_file_info, dir="bar") - def test_sftpClient_is_active(self, mock_ftpClient_sftpClient, stub_creds): + def test_sftpClient_get_file_data(self, mock_ftpClient_sftpClient, stub_creds): + stub_creds["port"] = "22" + ftp = _sftpClient(**stub_creds) + file_data = ftp.get_file_data(file_name="foo.mrc", dir="testdir") + assert file_data.file_name == "foo.mrc" + assert file_data.file_mtime == 1704070800 + assert file_data.file_size == 140401 + assert file_data.file_mode == 33188 + assert file_data.file_uid == 0 + assert file_data.file_gid == 0 + assert file_data.file_atime is None + + def test_sftpClient_get_file_data_not_found(self, mock_file_error, stub_creds): + stub_creds["port"] = "22" + sftp = _sftpClient(**stub_creds) + with pytest.raises(RetrieverFileError): + sftp.get_file_data(file_name="foo.mrc", dir="testdir") + + def test_sftpClient_list_file_data(self, mock_ftpClient_sftpClient, stub_creds): + stub_creds["port"] = "22" + ftp = _sftpClient(**stub_creds) + files = ftp.list_file_data(dir="testdir") + assert all(isinstance(file, FileInfo) for file in files) + assert len(files) == 1 + assert files[0].file_name == "foo.mrc" + assert files[0].file_mtime == 1704070800 + assert files[0].file_size == 140401 + assert files[0].file_mode == 33188 + assert files[0].file_uid == 0 + assert files[0].file_gid == 0 + assert files[0].file_atime is None + + def test_sftpClient_list_file_data_not_found(self, mock_file_error, stub_creds): + stub_creds["port"] = "22" + sftp = _sftpClient(**stub_creds) + with pytest.raises(RetrieverFileError): + sftp.list_file_data(dir="testdir") + + def test_sftpClient_is_active_true(self, mock_ftpClient_sftpClient, stub_creds): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) live_connection = sftp.is_active() assert live_connection is True - def test_sftpClient_is_inactive(self, mock_connection_dropped, stub_creds): + def test_sftpClient_is_active_false(self, 
mock_connection_dropped, stub_creds): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) live_connection = sftp.is_active() assert live_connection is False - def test_sftpClient_list_remote_file_data( - self, mock_ftpClient_sftpClient, stub_creds + def test_sftpClient_write_file( + self, mock_ftpClient_sftpClient, mock_file_info, stub_creds ): stub_creds["port"] = "22" - ftp = _sftpClient(**stub_creds) - files = ftp.list_remote_file_data("testdir") - assert files == [ - File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_uid=0, - file_gid=0, - file_atime=None, - file_mode=33188, - ) - ] - - def test_sftpClient_list_remote_file_data_not_found( - self, mock_file_error, stub_creds + sftp = _sftpClient(**stub_creds) + file_obj = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + remote_file = sftp.write_file(file=file_obj, dir="bar", remote=True) + local_file = sftp.write_file(file=file_obj, dir="bar", remote=False) + assert remote_file.file_mtime == 1704070800 + assert local_file.file_mtime == 1704070800 + + def test_sftpClient_write_file_not_found_remote( + self, mock_file_error, mock_file_info, stub_creds, caplog ): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - with pytest.raises(OSError): - sftp.list_remote_file_data("testdir") + file_obj = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + sftp.write_file(file=file_obj, dir="bar", remote=True) + assert ( + f"Unable to write {mock_file_info.file_name} to remote directory" + in caplog.text + ) - def test_sftpClient_upload_file(self, mock_ftpClient_sftpClient, stub_creds): + def test_sftpClient_write_file_not_found_local( + self, mock_file_error, mock_file_info, stub_creds, caplog + ): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - file = sftp.upload_file(file="foo.mrc", local_dir="foo", remote_dir="bar") - assert file.file_mtime == 1704070800 + file_obj = 
File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + sftp.write_file(file=file_obj, dir="bar", remote=False) + assert ( + f"Unable to write {mock_file_info.file_name} to local directory" + in caplog.text + ) - def test_sftpClient_upload_file_not_found(self, mock_file_error, stub_creds): + def test_sftpClient_write_file_no_file_stream( + self, mock_file_error, mock_file_info, stub_creds + ): stub_creds["port"] = "22" sftp = _sftpClient(**stub_creds) - with pytest.raises(OSError): - sftp.upload_file(file="foo.mrc", local_dir="foo", remote_dir="bar") + with pytest.raises(AttributeError) as exc: + sftp.write_file(file=mock_file_info, dir="bar", remote=False) + assert "'FileInfo' object has no attribute 'file_stream'" in str(exc.value) @pytest.mark.livetest class TestLiveClients: def test_ftpClient_live_test(self, live_ftp_creds): remote_dir = live_ftp_creds["remote_dir"] - del live_ftp_creds["remote_dir"], live_ftp_creds["vendor"] + del live_ftp_creds["remote_dir"], live_ftp_creds["name"] live_ftp = _ftpClient(**live_ftp_creds) - files = live_ftp.list_remote_file_data(remote_dir) - file_names = [file.file_name for file in files] - file_data = live_ftp.get_remote_file_data("Sample_Full_RDA.mrc", remote_dir) + file_list = live_ftp.list_file_data(dir=remote_dir) + file_names = [file.file_name for file in file_list] + file_data = live_ftp.get_file_data( + file_name="Sample_Full_RDA.mrc", dir=remote_dir + ) + fetched_file = live_ftp.fetch_file(file_data, remote_dir) assert "Sample_Full_RDA.mrc" in file_names assert "220" in live_ftp.connection.getwelcome() assert file_data.file_size == 7015 assert file_data.file_mode == 33188 + assert fetched_file.file_stream.getvalue()[0:1] == b"0" def test_ftpClient_live_test_no_creds(self, stub_creds): - with pytest.raises(OSError) as exc: + with pytest.raises(OSError): stub_creds["port"] = "21" _ftpClient(**stub_creds) - assert "getaddrinfo failed" in str(exc) def 
test_ftpClient_live_test_error_perm(self, live_ftp_creds): - del live_ftp_creds["remote_dir"], live_ftp_creds["vendor"] - with pytest.raises(ftplib.error_perm) as exc: + del live_ftp_creds["remote_dir"], live_ftp_creds["name"] + with pytest.raises(RetrieverAuthenticationError): live_ftp_creds["username"] = "bpl" _ftpClient(**live_ftp_creds) - assert "Login incorrect" in str(exc) def test_sftpClient_live_test(self, live_sftp_creds): remote_dir = live_sftp_creds["remote_dir"] - del live_sftp_creds["remote_dir"], live_sftp_creds["vendor"] + del live_sftp_creds["remote_dir"], live_sftp_creds["name"] live_sftp = _sftpClient(**live_sftp_creds) - files = live_sftp.list_remote_file_data(remote_dir) - file_data = live_sftp.get_remote_file_data("20049552_NYPL.mrc", remote_dir) + file_list = live_sftp.list_file_data(dir=remote_dir) + file_data = live_sftp.get_file_data( + file_name=file_list[0].file_name, dir=remote_dir + ) + fetched_file = live_sftp.fetch_file(file=file_data, dir=remote_dir) + assert live_sftp.connection.get_channel().active == 1 assert datetime.datetime.fromtimestamp( - files[0].file_mtime + file_list[0].file_mtime ) >= datetime.datetime(2020, 1, 1) - assert len(files) > 1 - assert live_sftp.connection.get_channel().active == 1 - assert file_data.file_size == 18759 - assert file_data.file_mode == 33261 + assert len(file_list) > 1 + assert file_data.file_size > 1 + assert file_data.file_mode > 32768 + assert fetched_file.file_stream.getvalue()[0:1] == b"0" def test_sftpClient_live_test_auth_error(self, live_sftp_creds): - del live_sftp_creds["remote_dir"], live_sftp_creds["vendor"] - with pytest.raises(paramiko.AuthenticationException) as exc: + del live_sftp_creds["remote_dir"], live_sftp_creds["name"] + with pytest.raises(RetrieverAuthenticationError): live_sftp_creds["username"] = "bpl" _sftpClient(**live_sftp_creds) - assert "Authentication failed." 
in str(exc) - - def test_sftpClient_NSDROP(self, NSDROP_creds, live_sftp_creds): - local_test_dir = "C://Users/ckostelic/github/file-retriever/temp" - nsdrop_remote_dir = "NSDROP/file_retriever_test/test_vendor" - ev_remote_dir = live_sftp_creds["remote_dir"] - ev_creds = { - k: v - for k, v in live_sftp_creds.items() - if k != "remote_dir" and k != "vendor" - } - ev_sftp = _sftpClient(**ev_creds) - ev_files = ev_sftp.list_remote_file_data(ev_remote_dir) - ev_sftp.download_file(ev_files[0].file_name, ev_remote_dir, local_test_dir) - nsdrop_sftp = _sftpClient(**NSDROP_creds) - nsdrop_file = nsdrop_sftp.upload_file( - ev_files[0].file_name, nsdrop_remote_dir, local_test_dir - ) - assert ev_files[0].file_name in os.listdir(local_test_dir) - assert ev_files[0].file_name == nsdrop_file.file_name + + def test_sftpClient_NSDROP(self, NSDROP_creds): + remote_dir = "NSDROP/file_retriever_test/test_vendor" + del NSDROP_creds["remote_dir"], NSDROP_creds["name"] + live_sftp = _sftpClient(**NSDROP_creds) + get_file = live_sftp.get_file_data(file_name="test.txt", dir=remote_dir) + fetched_file = live_sftp.fetch_file(file=get_file, dir=remote_dir) + assert fetched_file.file_stream.getvalue() == b"" + assert get_file.file_name == "test.txt" + assert get_file.file_size == 0 diff --git a/tests/test_connect.py b/tests/test_connect.py index 5e91905..f8d0659 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1,9 +1,20 @@ -import ftplib -import paramiko +import datetime +import io +import logging +import logging.config import pytest from file_retriever.connect import Client from file_retriever._clients import _ftpClient, _sftpClient -from file_retriever.file import File +from file_retriever.file import FileInfo, File +from file_retriever.utils import logger_config +from file_retriever.errors import ( + RetrieverFileError, + RetrieverAuthenticationError, +) + +logger = logging.getLogger("file_retriever") +config = logger_config() +logging.config.dictConfig(config) class 
TestMockClient: @@ -26,10 +37,10 @@ def test_Client(self, mock_Client, stub_creds, port, client_type): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - assert connect.vendor == "test" + assert connect.name == "test" assert connect.host == "ftp.testvendor.com" assert connect.port == port assert connect.remote_dir == "testdir" @@ -39,7 +50,7 @@ def test_Client_invalid_port(self, mock_Client, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (1, "testdir", "test") with pytest.raises(ValueError) as e: Client(**stub_creds) @@ -49,18 +60,18 @@ def test_Client_ftp_auth_error(self, mock_auth_error, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (21, "testdir", "test") - with pytest.raises(ftplib.error_perm): + with pytest.raises(RetrieverAuthenticationError): Client(**stub_creds) def test_Client_sftp_auth_error(self, mock_auth_error, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (22, "testdir", "test") - with pytest.raises(paramiko.AuthenticationException): + with pytest.raises(RetrieverAuthenticationError): Client(**stub_creds) @pytest.mark.parametrize( @@ -71,7 +82,7 @@ def test_Client_context_manager(self, mock_Client, stub_creds, port): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") with Client(**stub_creds) as connect: assert connect.session is not None @@ -84,7 +95,7 @@ def test_Client_check_connection(self, mock_Client, stub_creds, port): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) live_connection = connect.check_connection() @@ -94,298 +105,248 @@ def test_Client_check_connection(self, 
mock_Client, stub_creds, port): "port", [21, 22], ) - def test_Client_check_file_local(self, mock_Client_file_exists, stub_creds, port): + def test_Client_file_exists_true( + self, mock_Client_file_exists, stub_creds, port, mock_file_info + ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - file_exists = connect.check_file(file="foo.mrc", check_dir="bar", remote=False) - assert file_exists is True + get_file = connect.get_file_info(file_name="foo.mrc", remote_dir="testdir") + local_file_exists = connect.file_exists( + file=mock_file_info, dir="bar", remote=False + ) + assert mock_file_info.file_name == get_file.file_name + assert mock_file_info.file_size == get_file.file_size + remote_file_exists = connect.file_exists( + file=mock_file_info, dir="bar", remote=True + ) + assert local_file_exists is True + assert remote_file_exists is True - @pytest.mark.parametrize( - "port", - [21, 22], - ) - def test_Client_check_file_remote(self, mock_Client_file_exists, stub_creds, port): + def test_Client_file_exists_sftp_file_not_found( + self, mock_file_error, stub_creds, mock_file_info + ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "testdir", "test") + stub_creds["name"], + ) = (22, "testdir", "test") connect = Client(**stub_creds) - file_exists = connect.check_file(file="foo.mrc", check_dir="bar", remote=True) - assert file_exists is True + file_exists = connect.file_exists(file=mock_file_info, dir="bar", remote=True) + assert file_exists is False @pytest.mark.parametrize( - "port, dir, uid_gid", - [(21, "testdir", None), (21, None, None), (22, "testdir", 0), (22, None, 0)], + "port, dir", + [(21, "testdir"), (21, None), (22, "testdir"), (22, None)], ) - def test_Client_get_file_data(self, mock_Client, stub_creds, port, dir, uid_gid): + def test_Client_get_file(self, mock_Client, mock_file_info, stub_creds, port, dir): ( 
stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - file = connect.get_file_data(file="foo.mrc", remote_dir=dir) - assert file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=uid_gid, - file_gid=uid_gid, - file_atime=None, - ) + file = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + file = connect.get_file(file=file, remote_dir=dir) + assert isinstance(file, File) - def test_Client_ftp_get_file_data_not_found( - self, mock_connection_error_reply, stub_creds + def test_Client_ftp_get_file_error_perm( + self, mock_file_error, mock_file_info, stub_creds ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (21, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(ftplib.error_reply): - connect.get_file_data("foo.mrc", "testdir") + file_obj = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + connect.get_file(file=file_obj, remote_dir="bar_dir") - def test_Client_sftp_get_file_data_not_found(self, mock_file_error, stub_creds): + def test_Client_sftp_get_file_not_found( + self, mock_file_error, mock_file_info, stub_creds + ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (22, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(OSError): - connect.get_file_data("foo.mrc", "testdir") + file_obj = File.from_fileinfo(file=mock_file_info, file_stream=io.BytesIO(b"0")) + with pytest.raises(RetrieverFileError): + connect.get_file(file=file_obj, remote_dir="bar_dir") - @pytest.mark.parametrize("port, uid_gid", [(21, None), (22, 0)]) - def test_Client_list_files_in_dir(self, mock_Client, stub_creds, port, uid_gid): + @pytest.mark.parametrize( + "port, dir, uid_gid", + [(21, "testdir", None), (21, 
None, None), (22, "testdir", 0), (22, None, 0)], + ) + def test_Client_get_file_info(self, mock_Client, stub_creds, port, dir, uid_gid): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - all_files = connect.list_files_in_dir() - recent_files = connect.list_files_in_dir(time_delta=5, remote_dir="testdir") - assert all_files == [ - File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=uid_gid, - file_gid=uid_gid, - file_atime=None, - ) - ] - assert recent_files == [] + file = connect.get_file_info(file_name="foo.mrc", remote_dir=dir) + assert isinstance(file, FileInfo) + assert file.file_name == "foo.mrc" + assert file.file_mtime == 1704070800 + assert file.file_size == 140401 + assert file.file_mode == 33188 + assert file.file_uid == uid_gid + assert file.file_gid == uid_gid + assert file.file_atime is None - def test_Client_list_ftp_file_not_found( - self, mock_connection_error_reply, stub_creds - ): + def test_Client_ftp_get_file_info_not_found(self, mock_file_error, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (21, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(ftplib.error_reply): - connect.list_files_in_dir() + with pytest.raises(RetrieverFileError): + connect.get_file_info(file_name="foo.mrc", remote_dir="testdir") - def test_Client_list_sftp_file_not_found(self, mock_file_error, stub_creds): + def test_Client_sftp_get_file_info_not_found(self, mock_file_error, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (22, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(OSError): - connect.list_files_in_dir() + with pytest.raises(RetrieverFileError): + connect.get_file_info(file_name="foo.mrc", remote_dir="testdir") - @pytest.mark.parametrize( 
- "port, dir", - [(21, "testdir"), (21, None), (22, "testdir"), (22, None)], - ) - def test_Client_get_file(self, mock_Client, stub_creds, port, dir): + @pytest.mark.parametrize("port, uid_gid", [(21, None), (22, 0)]) + def test_Client_list_file_info(self, mock_Client, stub_creds, port, uid_gid): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - downloaded_file = connect.get_file( - "foo.mrc", remote_dir=dir, local_dir="baz_dir", check=False - ) - assert downloaded_file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=0, - file_gid=0, - file_atime=None, + all_files = connect.list_file_info() + recent_files_int = connect.list_file_info(time_delta=5, remote_dir="testdir") + recent_files_dt = connect.list_file_info( + time_delta=datetime.timedelta(days=5), remote_dir="testdir" ) + assert all(isinstance(file, FileInfo) for file in all_files) + assert all(isinstance(file, FileInfo) for file in recent_files_int) + assert all(isinstance(file, FileInfo) for file in recent_files_dt) + assert len(all_files) == 1 + assert len(recent_files_int) == 0 + assert len(recent_files_dt) == 0 + assert all_files[0].file_name == "foo.mrc" + assert all_files[0].file_mtime == 1704070800 + assert all_files[0].file_size == 140401 + assert all_files[0].file_mode == 33188 + assert all_files[0].file_uid == uid_gid + assert all_files[0].file_gid == uid_gid + assert all_files[0].file_atime is None + assert recent_files_int == [] + assert recent_files_dt == [] - @pytest.mark.parametrize("port", [21, 22]) - def test_Client_get_file_not_found(self, mock_file_error, stub_creds, port): + def test_Client_list_sftp_file_not_found(self, mock_file_error, stub_creds): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "testdir", "test") + stub_creds["name"], + ) = (22, "testdir", "test") connect = 
Client(**stub_creds) - with pytest.raises(OSError): - connect.get_file( - "foo.mrc", remote_dir="bar_dir", local_dir="baz_dir", check=False - ) + with pytest.raises(RetrieverFileError): + connect.list_file_info() @pytest.mark.parametrize( - "port", - [21, 22], + "port, check", + [(21, True), (21, False), (22, True), (22, False)], ) - def test_Client_get_check_file_exists_true( - self, mock_Client_file_exists, stub_creds, port + def test_Client_put_file( + self, mock_Client, mock_file_info, stub_creds, port, check ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(FileExistsError): - connect.get_file("foo.mrc", "testdir", check=True) + file = mock_file_info + file.file_stream = io.BytesIO(b"0") + local_file = connect.put_file(file=file, dir="bar", remote=False, check=check) + remote_file = connect.put_file(file=file, dir="bar", remote=True, check=check) + assert remote_file.file_mtime == 1704070800 + assert local_file.file_mtime == 1704070800 - @pytest.mark.parametrize( - "port, dir", - [(21, "testdir"), (21, None), (22, "testdir"), (22, None)], - ) - def test_Client_get_check_file_exists_false( - self, mock_Client, stub_creds, port, dir + def test_Client_ftp_put_file_error_perm( + self, mock_file_error, mock_file_info, stub_creds ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "testdir", "test") - connect = Client(**stub_creds) - downloaded_file = connect.get_file( - "foo.mrc", remote_dir=dir, local_dir="baz_dir", check=True - ) - assert downloaded_file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=0, - file_gid=0, - file_atime=None, - ) - - @pytest.mark.parametrize( - "port, dir, uid_gid", - [(21, None, None), (21, "test", None), (22, None, 0), (22, "test", 0)], - ) - def test_Client_put_file(self, mock_Client, stub_creds, port, dir, 
uid_gid): - ( - stub_creds["port"], - stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "foo", "test") - connect = Client(**stub_creds) - put_file = connect.put_file( - "foo.mrc", remote_dir=dir, local_dir="baz_dir", check=False - ) - assert put_file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_uid=uid_gid, - file_gid=uid_gid, - file_atime=None, - ) - - @pytest.mark.parametrize("port", [21, 22]) - def test_Client_put_file_not_found(self, mock_file_error, stub_creds, port): - ( - stub_creds["port"], - stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "testdir", "test") + stub_creds["name"], + ) = (21, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(OSError): - connect.put_file( - "foo.mrc", remote_dir="bar_dir", local_dir="baz_dir", check=False - ) + mock_file_info.file_stream = io.BytesIO(b"0") + with pytest.raises(RetrieverFileError): + connect.put_file(file=mock_file_info, dir="bar", remote=True, check=False) - def test_Client_put_client_error_reply( - self, mock_connection_error_reply, stub_creds + def test_Client_ftp_put_file_OSError( + self, mock_file_error, mock_file_info, stub_creds ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (21, "testdir", "test") - client = Client(**stub_creds) - with pytest.raises(ftplib.error_reply): - client.put_file( - "foo.mrc", remote_dir="bar_dir", local_dir="baz_dir", check=False - ) + connect = Client(**stub_creds) + mock_file_info.file_stream = io.BytesIO(b"0") + with pytest.raises(RetrieverFileError): + connect.put_file(file=mock_file_info, dir="bar", remote=False, check=False) @pytest.mark.parametrize( - "port", - [21, 22], + "remote", + [True, False], ) - def test_Client_put_check_file_exists_true( - self, mock_Client_file_exists, stub_creds, port + def test_Client_sftp_put_file_OSError( + self, mock_file_error, mock_file_info, stub_creds, remote ): ( 
stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], - ) = (port, "testdir", "test") + stub_creds["name"], + ) = (22, "testdir", "test") connect = Client(**stub_creds) - with pytest.raises(FileExistsError): - connect.put_file( - "foo.mrc", remote_dir="bar_dir", local_dir="baz_dir", check=True - ) + mock_file_info.file_stream = io.BytesIO(b"0") + with pytest.raises(RetrieverFileError): + connect.put_file(file=mock_file_info, dir="bar", remote=remote, check=False) @pytest.mark.parametrize( - "port, uid_gid", - [(21, None), (22, 0)], + "port, remote", + [(21, True), (21, False), (22, True), (22, False)], ) - def test_Client_put_check_file_exists_false( - self, mock_Client, stub_creds, port, uid_gid + def test_Client_put_file_exists( + self, mock_Client_file_exists, mock_file_info, stub_creds, caplog, port, remote ): ( stub_creds["port"], stub_creds["remote_dir"], - stub_creds["vendor"], + stub_creds["name"], ) = (port, "testdir", "test") connect = Client(**stub_creds) - uploaded_file = connect.put_file( - "foo.mrc", remote_dir="bar_dir", local_dir="baz_dir", check=True - ) - assert uploaded_file == File( - file_name="foo.mrc", - file_mtime=1704070800, - file_size=140401, - file_mode=33188, - file_gid=uid_gid, - file_uid=uid_gid, - file_atime=None, + mock_file_info.file_stream = io.BytesIO(b"0") + connect.put_file(file=mock_file_info, dir="bar", remote=remote, check=True) + assert ( + f"Skipping {mock_file_info.file_name}. File already exists in `bar`." 
+ in caplog.text ) @pytest.mark.livetest def test_Client_ftp_live_test(live_ftp_creds): live_ftp = Client(**live_ftp_creds) - files = live_ftp.list_files_in_dir() + files = live_ftp.list_file_info() assert len(files) > 1 assert "220" in live_ftp.session.connection.getwelcome() @@ -393,6 +354,6 @@ def test_Client_ftp_live_test(live_ftp_creds): @pytest.mark.livetest def test_Client_sftp_live_test(live_sftp_creds): live_sftp = Client(**live_sftp_creds) - files = live_sftp.list_files_in_dir() + files = live_sftp.list_file_info() assert len(files) > 1 assert live_sftp.session.connection.get_channel().active == 1 diff --git a/tests/test_file.py b/tests/test_file.py index 5544b77..7fe2f33 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,88 +1,103 @@ -import os +import io import paramiko import pytest -from file_retriever.file import File +from file_retriever.file import FileInfo, File -def test_File(): - file = File(file_name="foo.mrc", file_mtime=1704070800) +def test_FileInfo(): + file = FileInfo( + file_name="foo.mrc", file_mtime=1704070800, file_mode="-rw-r--r--", file_size=1 + ) assert file.file_name == "foo.mrc" assert file.file_mtime == 1704070800 + assert file.file_mode == 33188 + assert file.file_size == 1 + assert file.file_gid is None + assert file.file_uid is None + assert file.file_atime is None assert isinstance(file.file_name, str) assert isinstance(file.file_mtime, int) - assert isinstance(file, File) + assert isinstance(file, FileInfo) -def test_File_from_stat_data(mock_file_data): - foo_attr = paramiko.SFTPAttributes.from_stat( - obj=os.stat("foo.mrc"), filename="foo.mrc" - ) - foo = File.from_stat_data(data=foo_attr) - bar_attr = paramiko.SFTPAttributes.from_stat(obj=os.stat("bar.mrc")) - bar = File.from_stat_data(data=bar_attr, file_name="bar.mrc") - baz_attr = paramiko.SFTPAttributes.from_stat(obj=os.stat("baz.mrc")) +def test_FileInfo_from_stat_data(mock_sftp_attr): + foo = FileInfo.from_stat_data(data=mock_sftp_attr) + bar_attr = 
mock_sftp_attr + bar_attr.filename = None + bar = FileInfo.from_stat_data(data=bar_attr, file_name="bar.mrc") + baz_attr = mock_sftp_attr + baz_attr.filename, baz_attr.st_mode = None, None baz_attr.longname = ( - "-rw-r--r-- 1 0 0 140401 Jan 1 00:01 baz.mrc" + "-rwxrwxrwx 1 0 0 140401 Jan 1 00:01 baz.mrc" ) - baz = File.from_stat_data(data=baz_attr) - assert isinstance(foo_attr, paramiko.SFTPAttributes) + baz = FileInfo.from_stat_data(data=baz_attr) + assert isinstance(bar_attr, paramiko.SFTPAttributes) + assert isinstance(baz_attr, paramiko.SFTPAttributes) assert foo.file_name == "foo.mrc" - assert foo.file_mtime == 1704070800 - assert foo.file_size == 140401 - assert foo.file_uid == 0 - assert foo.file_gid == 0 - assert foo.file_mode == 33188 assert bar.file_name == "bar.mrc" assert baz.file_name == "baz.mrc" + assert foo.file_mtime == 1704070800 + assert bar.file_mtime == 1704070800 + assert baz.file_mtime == 1704070800 + assert foo.file_mode == 33188 + assert bar.file_mode == 33188 + assert baz.file_mode == 33279 -def test_File_from_stat_data_no_filename(mock_file_data): - sftp_attr = paramiko.SFTPAttributes.from_stat(obj=os.stat("foo.mrc")) +def test_FileInfo_from_stat_data_no_file_name(mock_sftp_attr): + sftp_attr = mock_sftp_attr + sftp_attr.filename = None with pytest.raises(AttributeError) as exc: - File.from_stat_data(data=sftp_attr) + FileInfo.from_stat_data(data=sftp_attr) assert "No filename provided" in str(exc) -def test_File_from_stat_data_no_st_mtime(mock_file_data): - sftp_attr = paramiko.SFTPAttributes.from_stat( - obj=os.stat("foo.mrc"), filename="foo.mrc" - ) +def test_FileInfo_from_stat_data_no_file_size(mock_sftp_attr): + sftp_attr = mock_sftp_attr + sftp_attr.st_size = None + with pytest.raises(AttributeError) as exc: + FileInfo.from_stat_data(data=sftp_attr) + assert "No file size provided" in str(exc) + + +def test_FileInfo_from_stat_data_no_file_mtime(mock_sftp_attr): + sftp_attr = mock_sftp_attr delattr(sftp_attr, "st_mtime") with 
pytest.raises(AttributeError) as exc: - File.from_stat_data(data=sftp_attr) + FileInfo.from_stat_data(data=sftp_attr) assert "No file modification time provided" in str(exc) -def test_File_from_stat_data_None_st_mtime(mock_file_data): - sftp_attr = paramiko.SFTPAttributes.from_stat( - obj=os.stat("foo.mrc"), filename="foo.mrc" - ) - sftp_attr.st_mtime = None +def test_FileInfo_from_stat_data_no_file_mode(mock_sftp_attr): + sftp_attr = mock_sftp_attr + sftp_attr.st_mode = None with pytest.raises(AttributeError) as exc: - File.from_stat_data(data=sftp_attr) - assert "No file modification time provided" in str(exc) + FileInfo.from_stat_data(data=sftp_attr) + assert "No file mode provided" in str(exc) @pytest.mark.parametrize( "str_time, mtime", [ ( - "220 20240101010000", + "20240101010000", 1704070800, ), ( - "220 20240202020202", + "20240202020202", 1706839322, ), ( - "220 20240303030303", + "20240303030303", 1709434983, ), ], ) -def test_File_parse_mdtm_time(str_time, mtime): - parsed = File.parse_mdtm_time(str_time) - assert parsed == mtime +def test_FileInfo_parse_mdtm_time(str_time, mtime): + file = FileInfo( + file_name="foo.mrc", file_mtime=str_time, file_mode=33188, file_size=1 + ) + assert file.file_mtime == mtime @pytest.mark.parametrize( @@ -100,8 +115,61 @@ def test_File_parse_mdtm_time(str_time, mtime): "-rxwrxwrxw", 33279, ), + ("-r--------", 33024), ], ) -def test_File_parse_permissions(str_permissions, decimal_permissions): - parsed = File.parse_permissions(str_permissions) - assert parsed == decimal_permissions +def test_FileInfo_parse_permissions(str_permissions, decimal_permissions): + file = FileInfo( + file_name="foo.mrc", + file_mtime=1704070800, + file_mode=str_permissions, + file_size=1, + ) + assert file.file_mode == decimal_permissions + + +def test_File(): + foo = File( + file_name="foo.mrc", + file_mtime=1704070800, + file_size=1, + file_uid=None, + file_gid=None, + file_atime=None, + file_mode="-rw-r--r--", + 
file_stream=io.BytesIO(b"foo"), + ) + bar = File( + file_name="bar.txt", + file_mtime="20240101010000", + file_size=1, + file_uid=0, + file_gid=0, + file_atime=None, + file_mode="-rw-r--r--", + file_stream=io.BytesIO(b"foo"), + ) + assert foo.file_name == "foo.mrc" + assert foo.file_mtime == 1704070800 + assert bar.file_name == "bar.txt" + assert bar.file_mtime == 1704070800 + assert isinstance(foo.file_name, str) + assert isinstance(foo.file_mtime, int) + assert isinstance(bar.file_name, str) + assert isinstance(bar.file_mtime, int) + assert isinstance(foo, FileInfo) + assert isinstance(foo, File) + assert isinstance(bar, FileInfo) + assert isinstance(bar, File) + + +def test_File_from_fileinfo(mock_file_info): + file = File.from_fileinfo( + file=mock_file_info, + file_stream=io.BytesIO(b"foo"), + ) + assert file.file_name == "foo.mrc" + assert file.file_mtime == 1704070800 + assert isinstance(file.file_name, str) + assert isinstance(file.file_mtime, int) + assert isinstance(file, FileInfo) diff --git a/tests/test_utils.py b/tests/test_utils.py index be0b1ab..f8c5180 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import logging.config import os import pytest -from file_retriever.utils import logger_config, vendor_config +from file_retriever.utils import logger_config, client_config def test_logger_config(): @@ -40,19 +40,25 @@ def test_logger_config_stream(message, level, caplog): def test_vendor_config(mocker): yaml_string = """ - TEST_HOST: foo - TEST_USER: bar - TEST_PASSWORD: baz - TEST_PORT: '22' - TEST_SRC: test_src + FOO_HOST: foo + FOO_USER: bar + FOO_PASSWORD: baz + FOO_PORT: '21' + FOO_SRC: foo_src + BAR_HOST: foo + BAR_USER: bar + BAR_PASSWORD: baz + BAR_PORT: '22' + BAR_SRC: bar_src """ m = mocker.mock_open(read_data=yaml_string) mocker.patch("builtins.open", m) - vendor_config("foo.yaml") - - assert os.environ["TEST_HOST"] == "foo" - assert os.environ["TEST_USER"] == "bar" - assert os.environ["TEST_PASSWORD"] == "baz" - 
assert os.environ["TEST_PORT"] == "22" - assert os.environ["TEST_SRC"] == "test_src" + client_list = client_config("foo.yaml") + assert len(client_list) == 2 + assert client_list == ["FOO", "BAR"] + assert os.environ["FOO_HOST"] == "foo" + assert os.environ["FOO_USER"] == "bar" + assert os.environ["FOO_PASSWORD"] == "baz" + assert os.environ["FOO_PORT"] == "21" + assert os.environ["FOO_SRC"] == "foo_src"