Es 1011531 #355

Draft
wants to merge 3 commits into base: main
31 changes: 24 additions & 7 deletions examples/query_execute.py
@@ -1,13 +1,30 @@
import threading
from databricks import sql
import os
import logging


logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler('pysqllogs.log')
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(server_hostname = os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path = os.getenv("DATABRICKS_HTTP_PATH"),
access_token = os.getenv("DATABRICKS_TOKEN")) as connection:

with connection.cursor() as cursor:
cursor.execute("SELECT * FROM default.diamonds LIMIT 2")
result = cursor.fetchall()
access_token = os.getenv("DATABRICKS_TOKEN"),
# max_download_threads = 2
) as connection:

for row in result:
print(row)
with connection.cursor(
# arraysize=100
) as cursor:
# cursor.execute("SELECT * FROM range(0, 10000000) AS t1 LEFT JOIN (SELECT 1) AS t2")
cursor.execute("SELECT * FROM andre.plotly_iot_dashboard.bronze_sensors limit 1000001")
try:
result = cursor.fetchall()
print(f"result length: {len(result)}")
except sql.exc.ResultSetDownloadError as e:
print(f"error: {e}")
# buffer_size_bytes=4857600
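
Taken together, the changes to this example amount to roughly the following end state (a sketch assembled from the diff above; the table name is a placeholder, and the commented-out tuning knobs such as max_download_threads, arraysize, and buffer_size_bytes are omitted):

```python
import os
import logging

from databricks import sql

# Route connector debug logs to a file, as in the diff above.
logger = logging.getLogger("databricks.sql")
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler("pysqllogs.log")
fh.setFormatter(logging.Formatter("%(asctime)s %(process)d %(thread)d %(message)s"))
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)

with sql.connect(
    server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
    http_path=os.getenv("DATABRICKS_HTTP_PATH"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
) as connection:
    with connection.cursor() as cursor:
        # Any query large enough to trigger cloud fetch will do; the table name is a placeholder.
        cursor.execute("SELECT * FROM my_catalog.my_schema.large_table LIMIT 1000001")
        try:
            result = cursor.fetchall()
            print(f"result length: {len(result)}")
        except sql.exc.ResultSetDownloadError as e:
            # New in this PR: a failed cloud fetch download surfaces as an exception
            # instead of being retried indefinitely by the download manager.
            print(f"error: {e}")
```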
37 changes: 6 additions & 31 deletions src/databricks/sql/cloudfetch/download_manager.py
@@ -8,6 +8,7 @@
ResultSetDownloadHandler,
DownloadableResultSettings,
)
from databricks.sql.exc import ResultSetDownloadError
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logger = logging.getLogger(__name__)
@@ -34,8 +35,6 @@ def __init__(self, max_download_threads: int, lz4_compressed: bool):
self.download_handlers: List[ResultSetDownloadHandler] = []
self.thread_pool = ThreadPoolExecutor(max_workers=max_download_threads + 1)
self.downloadable_result_settings = DownloadableResultSettings(lz4_compressed)
self.fetch_need_retry = False
self.num_consecutive_result_file_download_retries = 0

def add_file_links(
self, t_spark_arrow_result_links: List[TSparkArrowResultLink]
@@ -81,13 +80,15 @@ def get_next_downloaded_file(

# Find next file
idx = self._find_next_file_index(next_row_offset)
# is this correct?
if idx is None:
self._shutdown_manager()
logger.debug("could not find next file index")
return None
handler = self.download_handlers[idx]

# Check (and wait) for download status
if self._check_if_download_successful(handler):
if handler.is_file_download_successful():
# Buffer should be empty so set buffer to new ArrowQueue with result_file
result = DownloadedFile(
handler.result_file,
@@ -97,9 +98,9 @@ def get_next_downloaded_file(
self.download_handlers.pop(idx)
# Return True upon successful download to continue loop and not force a retry
return result
# Download was not successful for next download item, force a retry
# Download was not successful for next download item. Fail
self._shutdown_manager()
return None
raise ResultSetDownloadError(f"Download failed for result set starting at {next_row_offset}")

def _remove_past_handlers(self, next_row_offset: int):
# Any link in which its start to end range doesn't include the next row to be fetched does not need downloading
@@ -133,32 +134,6 @@ def _find_next_file_index(self, next_row_offset):
]
return next_indices[0] if len(next_indices) > 0 else None

def _check_if_download_successful(self, handler: ResultSetDownloadHandler):
# Check (and wait until download finishes) if download was successful
if not handler.is_file_download_successful():
if handler.is_link_expired:
self.fetch_need_retry = True
return False
elif handler.is_download_timedout:
# Consecutive file retries should not exceed threshold in settings
if (
self.num_consecutive_result_file_download_retries
>= self.downloadable_result_settings.max_consecutive_file_download_retries
):
self.fetch_need_retry = True
return False
self.num_consecutive_result_file_download_retries += 1

# Re-submit handler run to thread pool and recursively check download status
self.thread_pool.submit(handler.run)
return self._check_if_download_successful(handler)
else:
self.fetch_need_retry = True
return False

self.num_consecutive_result_file_download_retries = 0
self.fetch_need_retry = False
return True

def _shutdown_manager(self):
# Clear download handlers and shutdown the thread pool
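
With the retry bookkeeping removed, the manager no longer re-submits failed handlers to the thread pool: a download that did not finish successfully shuts the manager down and raises the new exception. A minimal sketch of what a consumer of the manager now sees (the `manager` and `row_offset` names are stand-ins for the fetch loop, which is outside this diff):

```python
from databricks.sql.exc import ResultSetDownloadError


def fetch_next_chunk(manager, row_offset):
    """Stand-in for the code that drains the download manager."""
    try:
        downloaded_file = manager.get_next_downloaded_file(row_offset)
    except ResultSetDownloadError:
        # The manager has already cleared its handlers and shut down its
        # thread pool before raising, so surface the error instead of retrying here.
        raise
    if downloaded_file is None:
        # No handler covers row_offset (no next file index was found).
        return None
    return downloaded_file
```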
118 changes: 92 additions & 26 deletions src/databricks/sql/cloudfetch/downloader.py
@@ -1,15 +1,19 @@
import logging
from dataclasses import dataclass

from datetime import datetime
import requests
import lz4.frame
import threading
import time

import os
from threading import get_ident
from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink

logging.basicConfig(format="%(asctime)s %(message)s")
logger = logging.getLogger(__name__)

DEFAULT_CLOUD_FILE_TIMEOUT = int(os.getenv("DATABRICKS_CLOUD_FILE_TIMEOUT", 60))


@dataclass
class DownloadableResultSettings:
@@ -20,13 +24,17 @@ class DownloadableResultSettings:
is_lz4_compressed (bool): Whether file is expected to be lz4 compressed.
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
max_retries (int): Maximum number of consecutive download attempts before giving up. Default 5.
backoff_factor (int): Factor to increase wait time between retries.

"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
download_timeout: int = DEFAULT_CLOUD_FILE_TIMEOUT
max_retries: int = 5
backoff_factor: int = 2


class ResultSetDownloadHandler(threading.Thread):
@@ -57,16 +65,21 @@ def is_file_download_successful(self) -> bool:
else None
)
try:
logger.debug(
f"waiting for at most {timeout} seconds for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

if not self.is_download_finished.wait(timeout=timeout):
self.is_download_timedout = True
logger.debug(
"Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
self.settings.download_timeout,
self.result_link.startRowOffset,
self.result_link.startRowOffset + self.result_link.rowCount,
)
f"cloud fetch download timed out after {self.settings.download_timeout} seconds for link representing rows {self.result_link.startRowOffset} to {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return False
# In some cases is_download_finished is not set even though the file was downloaded successfully
return self.is_file_downloaded_successfully

logger.debug(
f"finish waiting for download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
except Exception as e:
logger.error(e)
return False
@@ -80,25 +93,38 @@ def run(self):
file, and signals to waiting threads that the download is finished and whether it was successful.
"""
self._reset()


# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return
try:
# Check if link is already expired or is expiring
if ResultSetDownloadHandler.check_link_expired(
self.result_link, self.settings.link_expiry_buffer_secs
):
self.is_link_expired = True
return

session = requests.Session()
session.timeout = self.settings.download_timeout
logger.debug(
f"started to download file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

try:
# Get the file via HTTP request
response = session.get(self.result_link.fileLink)
response = http_get_with_retry(
url=self.result_link.fileLink,
max_retries=self.settings.max_retries,
backoff_factor=self.settings.backoff_factor,
download_timeout=self.settings.download_timeout,
)

if not response.ok:
self.is_file_downloaded_successfully = False
if not response:
logger.error(
f"failed downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
return

logger.debug(
f"success downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)

# Save (and decompress if needed) the downloaded file
compressed_data = response.content
decompressed_data = (
@@ -109,18 +135,26 @@ def run(self):
self.result_file = decompressed_data

# The size of the downloaded file should match the size specified from TSparkArrowResultLink
self.is_file_downloaded_successfully = (
len(self.result_file) == self.result_link.bytesNum
success = len(self.result_file) == self.result_link.bytesNum
logger.debug(
f"download successful file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
self.is_file_downloaded_successfully = success
except Exception as e:
logger.debug(
f"exception downloading file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
logger.error(e)
self.is_file_downloaded_successfully = False

finally:
session and session.close()
logger.debug(
f"signal finished file: startRow {self.result_link.startRowOffset}, rowCount {self.result_link.rowCount}, endRow {self.result_link.startRowOffset + self.result_link.rowCount}"
)
# Awaken threads waiting for this to be true which signals the run is complete
self.is_download_finished.set()


def _reset(self):
"""
Reset download-related flags for every retry of run()
@@ -145,6 +179,9 @@ def check_link_expired(
link.expiryTime < current_time
or link.expiryTime - current_time < expiry_buffer_secs
):
logger.debug(
f"{(os.getpid(), get_ident())} - link expired"
)
return True
return False

@@ -171,3 +208,32 @@ def decompress_data(compressed_data: bytes) -> bytes:
uncompressed_data += data
start += num_bytes
return uncompressed_data


def http_get_with_retry(url, max_retries=5, backoff_factor=2, download_timeout=60):
attempts = 0

while attempts < max_retries:
try:
session = requests.Session()
# Pass the timeout per request; requests does not honor a `timeout` attribute set on the Session itself
response = session.get(url, timeout=download_timeout)

# Check if the response status code is in the 2xx range for success
if response.status_code == 200:
return response
else:
logger.error(response)
except requests.RequestException as e:
print(f"request failed with exception: {e}")
finally:
session.close()
# Exponential backoff before the next attempt
wait_time = backoff_factor ** attempts
logger.info(f"retrying in {wait_time} seconds...")
time.sleep(wait_time)

attempts += 1

logger.error(f"exceeded maximum number of retries ({max_retries}) while downloading result.")
return None
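
With the defaults above (max_retries=5, backoff_factor=2), the helper waits 1, 2, 4, and 8 seconds between successive failed attempts and, as written, sleeps a final 16 seconds after the fifth attempt before returning None. A minimal usage sketch, assuming the function above is in scope (it lives in databricks/sql/cloudfetch/downloader.py); the URL is a placeholder, real callers pass the presigned fileLink from a TSparkArrowResultLink:

```python
response = http_get_with_retry(
    url="https://example-bucket.s3.amazonaws.com/result-chunk",  # placeholder presigned link
    max_retries=5,
    backoff_factor=2,
    download_timeout=60,  # DEFAULT_CLOUD_FILE_TIMEOUT unless overridden via the env var
)
if response is None:
    # All attempts failed; ResultSetDownloadHandler.run() treats this as an
    # unsuccessful download and logs the failed link.
    ...
else:
    compressed_data = response.content  # raw (possibly lz4-compressed) Arrow stream bytes
```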
4 changes: 4 additions & 0 deletions src/databricks/sql/exc.py
@@ -115,3 +115,7 @@ class SessionAlreadyClosedError(RequestError):

class CursorAlreadyClosedError(RequestError):
"""Thrown if CancelOperation receives a code 404. ThriftBackend should gracefully proceed as this is expected."""


class ResultSetDownloadError(RequestError):
"""Thrown if there was an error during the download of a result set"""
20 changes: 17 additions & 3 deletions src/databricks/sql/thrift_backend.py
@@ -393,7 +393,7 @@ def attempt_request(attempt):
try:
this_method_name = getattr(method, "__name__")

logger.debug("Sending request: {}(<REDACTED>)".format(this_method_name))
logger.debug("sending thrift request: {}(<REDACTED>)".format(this_method_name))
unsafe_logger.debug("Sending request: {}".format(request))

# These three lines are no-ops if the v3 retry policy is not in use
@@ -406,7 +406,7 @@ def attempt_request(attempt):

# We need to call type(response) here because thrift doesn't implement __name__ attributes for thrift responses
logger.debug(
"Received response: {}(<REDACTED>)".format(type(response).__name__)
"received thrift response: {}(<REDACTED>)".format(type(response).__name__)
)
unsafe_logger.debug("Received response: {}".format(response))
return response
@@ -764,6 +764,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
lz4_compressed = t_result_set_metadata_resp.lz4Compressed
is_staging_operation = t_result_set_metadata_resp.isStagingOperation
if direct_results and direct_results.resultSet:
logger.debug(
f"received direct results"
)
assert direct_results.resultSet.results.startRowOffset == 0
assert direct_results.resultSetMetadata

@@ -776,6 +779,9 @@ def _results_message_to_execute_response(self, resp, operation_state):
description=description,
)
else:
logger.debug(
f"must fetch results"
)
arrow_queue_opt = None
return ExecuteResponse(
arrow_queue=arrow_queue_opt,
@@ -835,11 +841,15 @@ def execute_command(
max_bytes,
lz4_compression,
cursor,
use_cloud_fetch=True,
use_cloud_fetch=True, # change here
parameters=[],
):
assert session_handle is not None

logger.debug(
f"executing: cloud fetch: {use_cloud_fetch}, max rows: {max_rows}, max bytes: {max_bytes}"
)

spark_arrow_types = ttypes.TSparkArrowTypes(
timestampAsArrow=self._use_arrow_native_timestamps,
decimalAsArrow=self._use_arrow_native_decimals,
@@ -955,6 +965,9 @@ def get_columns(
return self._handle_execute_response(resp, cursor)

def _handle_execute_response(self, resp, cursor):
logger.debug(
f"got execute response"
)
cursor.active_op_handle = resp.operationHandle
self._check_direct_results_for_error(resp.directResults)

@@ -975,6 +988,7 @@ def fetch_results(
arrow_schema_bytes,
description,
):
logger.debug("started to fetch results")
assert op_handle is not None

req = ttypes.TFetchResultsReq(
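
The new thrift_backend debug lines go to the module logger (logging.getLogger(__name__)), so they can be captured the same way the example file configures the package-level logger. A minimal sketch, assuming the standard databricks.sql.thrift_backend module path:

```python
import logging

# Capture the new "sending thrift request" / "received thrift response" and
# cloud fetch related debug messages added in this PR.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(process)d %(thread)d %(name)s %(message)s",
)
logging.getLogger("databricks.sql.thrift_backend").setLevel(logging.DEBUG)
```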