Skip to content

Commit

Permalink
Merge pull request #158 from dandi/restream
Browse files Browse the repository at this point in the history
Reduce the sizes of a number of streams & futures
  • Loading branch information
jwodder committed Jul 10, 2024
2 parents 39d79c8 + ff44f1f commit 00d0714
Show file tree
Hide file tree
Showing 12 changed files with 368 additions and 190 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
In Development
--------------
- Reduced the sizes of a number of streams & futures

v0.4.0 (2024-07-09)
-------------------
- Set `Access-Control-Allow-Origin: *` header in all responses
Expand Down
25 changes: 2 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ publish = false

[dependencies]
anyhow = "1.0.86"
async-stream = "0.3.5"
async-trait = "0.1.81"
aws-config = { version = "1.5.3", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1.39.0"
aws-smithy-async = "1.2.1"
aws-smithy-runtime-api = "1.7.1"
aws-smithy-types-convert = { version = "0.60.8", features = ["convert-time"] }
axum = { version = "0.7.5", default-features = false, features = ["http1", "tokio", "tower-log"] }
Expand All @@ -33,6 +33,7 @@ itertools = "0.13.0"
memory-stats = "1.2.0"
moka = { version = "0.12.8", features = ["future"] }
percent-encoding = "2.3.1"
pin-project-lite = "0.2.14"
reqwest = { version = "0.12.5", default-features = false, features = ["json", "rustls-tls-native-roots"] }
reqwest-middleware = "0.3.2"
reqwest-retry = "0.6.0"
Expand Down
78 changes: 25 additions & 53 deletions src/dandi/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
mod dandiset_id;
mod streams;
mod types;
mod version_id;
pub(crate) use self::dandiset_id::*;
use self::streams::Paginate;
pub(crate) use self::types::*;
pub(crate) use self::version_id::*;
use crate::consts::S3CLIENT_CACHE_SIZE;
Expand All @@ -10,7 +12,6 @@ use crate::paths::{ParsePureDirPathError, PureDirPath, PurePath};
use crate::s3::{
BucketSpec, GetBucketRegionError, PrefixedS3Client, S3Client, S3Error, S3Location,
};
use async_stream::try_stream;
use futures_util::{Stream, TryStreamExt};
use moka::future::{Cache, CacheBuilder};
use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -51,20 +52,8 @@ impl DandiClient {
self.inner.get_json(url).await.map_err(Into::into)
}

fn paginate<T: DeserializeOwned + 'static>(
&self,
url: Url,
) -> impl Stream<Item = Result<T, DandiError>> + '_ {
try_stream! {
let mut url = Some(url);
while let Some(u) = url {
let page = self.inner.get_json::<Page<T>>(u).await?;
for r in page.results {
yield r;
}
url = page.next;
}
}
/// Return a [`Paginate`] stream yielding the items of every page of the
/// paginated API response starting at `url`, following each page's `next`
/// link until it is absent.
fn paginate<T: DeserializeOwned + 'static>(&self, url: Url) -> Paginate<T> {
Paginate::new(self, url)
}

async fn get_s3client(&self, loc: S3Location) -> Result<PrefixedS3Client, ZarrToS3Error> {
Expand Down Expand Up @@ -247,35 +236,26 @@ impl<'a> VersionEndpoint<'a> {
pub(crate) fn get_root_children(
&self,
) -> impl Stream<Item = Result<DandiResource, DandiError>> + '_ {
try_stream! {
let stream = self.get_entries_under_path(None);
tokio::pin!(stream);
while let Some(entry) = stream.try_next().await? {
self.get_entries_under_path(None)
.and_then(move |entry| async move {
match entry {
FolderEntry::Folder(subf) => yield DandiResource::Folder(subf),
FolderEntry::Folder(subf) => Ok(DandiResource::Folder(subf)),
FolderEntry::Asset { id, path } => match self.get_asset_by_id(&id).await {
Ok(asset) => yield DandiResource::Asset(asset),
Ok(asset) => Ok(DandiResource::Asset(asset)),
Err(DandiError::Http(HttpError::NotFound { .. })) => {
Err(DandiError::DisappearingAsset { asset_id: id, path })?;
Err(DandiError::DisappearingAsset { asset_id: id, path })
}
Err(e) => Err(e)?,
Err(e) => Err(e),
},
}
}
}
})
}

fn get_folder_entries(
&self,
path: &AssetFolder,
) -> impl Stream<Item = Result<FolderEntry, DandiError>> + '_ {
fn get_folder_entries(&self, path: &AssetFolder) -> Paginate<FolderEntry> {
self.get_entries_under_path(Some(&path.path))
}

fn get_entries_under_path(
&self,
path: Option<&PureDirPath>,
) -> impl Stream<Item = Result<FolderEntry, DandiError>> + '_ {
fn get_entries_under_path(&self, path: Option<&PureDirPath>) -> Paginate<FolderEntry> {
let mut url = self.client.get_url([
"dandisets",
self.dandiset_id.as_ref(),
Expand Down Expand Up @@ -304,8 +284,7 @@ impl<'a> VersionEndpoint<'a> {
.append_pair("metadata", "1")
.append_pair("order", "path");
let dirpath = path.to_dir_path();
let stream = self.client.paginate::<RawAsset>(url.clone());
tokio::pin!(stream);
let mut stream = self.client.paginate::<RawAsset>(url.clone());
while let Some(asset) = stream.try_next().await? {
if &asset.path == path {
return Ok(AtAssetPath::Asset(asset.try_into_asset(self)?));
Expand Down Expand Up @@ -371,8 +350,7 @@ impl<'a> VersionEndpoint<'a> {
match self.get_resource_with_s3(path).await? {
DandiResourceWithS3::Folder(folder) => {
let mut children = Vec::new();
let stream = self.get_folder_entries(&folder);
tokio::pin!(stream);
let mut stream = self.get_folder_entries(&folder);
while let Some(child) = stream.try_next().await? {
let child = match child {
FolderEntry::Folder(subf) => DandiResource::Folder(subf),
Expand All @@ -391,25 +369,19 @@ impl<'a> VersionEndpoint<'a> {
DandiResourceWithS3::Asset(Asset::Blob(r)) => Ok(DandiResourceWithChildren::Blob(r)),
DandiResourceWithS3::Asset(Asset::Zarr(zarr)) => {
let s3 = self.client.get_s3client_for_zarr(&zarr).await?;
let mut children = Vec::new();
{
let stream = s3.get_root_entries();
tokio::pin!(stream);
while let Some(child) = stream.try_next().await? {
children.push(zarr.make_resource(child));
}
}
let children = s3
.get_root_entries()
.map_ok(|child| zarr.make_resource(child))
.try_collect::<Vec<_>>()
.await?;
Ok(DandiResourceWithChildren::Zarr { zarr, children })
}
DandiResourceWithS3::ZarrFolder { folder, s3 } => {
let mut children = Vec::new();
{
let stream = s3.get_folder_entries(&folder.path);
tokio::pin!(stream);
while let Some(child) = stream.try_next().await? {
children.push(folder.make_resource(child));
}
}
let children = s3
.get_folder_entries(&folder.path)
.map_ok(|child| folder.make_resource(child))
.try_collect::<Vec<_>>()
.await?;
Ok(DandiResourceWithChildren::ZarrFolder { folder, children })
}
DandiResourceWithS3::ZarrEntry(r) => Ok(DandiResourceWithChildren::ZarrEntry(r)),
Expand Down
82 changes: 82 additions & 0 deletions src/dandi/streams.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use super::types::Page;
use super::{DandiClient, DandiError};
use crate::httputil::{Client, HttpError};
use futures_util::{future::BoxFuture, FutureExt, Stream};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use url::Url;

pin_project! {
// Implementing paginate() as a manually-implemented Stream instead of via
// async_stream lets us save about 4700 bytes on dandidav's top-level
// Futures.
#[must_use = "streams do nothing unless polled"]
pub(super) struct Paginate<T> {
// HTTP client used to fetch each page; a clone of the DandiClient's
// inner client, so the stream owns it and is independent of &self.
client: Client,
// Where we are in the request/yield cycle.  Note: no #[pin] — the
// state is only ever accessed through &mut (the BoxFuture it holds is
// already pinned on the heap).
state: PaginateState<T>,
}
}

// State machine for Paginate's poll_next().
enum PaginateState<T> {
// A request for the next page is in flight.
Requesting(BoxFuture<'static, Result<Page<T>, HttpError>>),
// Items from the most recently fetched page are being yielded one at a
// time; `next` is the URL of the following page, if any.
Yielding {
results: std::vec::IntoIter<T>,
next: Option<Url>,
},
// All pages exhausted, or a request failed; the stream yields nothing more.
Done,
}

impl<T> Paginate<T> {
/// Construct a stream over all items of the paginated endpoint at `url`.
///
/// The stream begins in the `Yielding` state with an empty batch of
/// results and `url` recorded as the "next page" link, so the very first
/// poll falls through to issuing the request for the first page.
pub(super) fn new(client: &DandiClient, url: Url) -> Self {
let initial = PaginateState::Yielding {
results: Vec::new().into_iter(),
next: Some(url),
};
Self {
client: client.inner.clone(),
state: initial,
}
}
}

impl<T> Stream for Paginate<T>
where
T: DeserializeOwned + 'static,
{
type Item = Result<T, DandiError>;

// Drive the state machine: yield buffered items from the current page,
// fetch the next page when the buffer runs dry and a `next` URL exists,
// and finish (or finish with an error) otherwise.  The loop lets a state
// transition be re-examined immediately within the same poll.
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match this.state {
// Page request in flight: park here (via ready!) until it resolves.
PaginateState::Requesting(ref mut fut) => match ready!(fut.as_mut().poll(cx)) {
Ok(page) => {
// Switch to yielding this page's results; loop re-enters
// the Yielding arm immediately.
*this.state = PaginateState::Yielding {
results: page.results.into_iter(),
next: page.next,
}
}
Err(e) => {
// A failed page request terminates the stream after the
// error is reported.
*this.state = PaginateState::Done;
return Some(Err(DandiError::from(e))).into();
}
},
PaginateState::Yielding {
ref mut results,
ref mut next,
} => {
if let Some(item) = results.next() {
return Some(Ok(item)).into();
} else if let Some(url) = next.take() {
// Buffer exhausted but another page exists: start fetching
// it.  get_json returns a 'static future (it clones the
// client), so boxing it here is sound.
*this.state =
PaginateState::Requesting(this.client.get_json::<Page<T>>(url).boxed());
} else {
// No buffered items and no next page: we are finished.
*this.state = PaginateState::Done;
}
}
PaginateState::Done => return None.into(),
}
}
}
}
29 changes: 13 additions & 16 deletions src/dav/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ impl DandiDav {
&self,
req: Request<Body>,
) -> Result<Response<Body>, Infallible> {
// Box large future:
let resp = match Box::pin(self.inner_handle_request(req)).await {
let resp = match self.inner_handle_request(req).await {
Ok(r) => r,
Err(e) if e.is_404() => {
let e = anyhow::Error::from(e);
Expand Down Expand Up @@ -246,12 +245,12 @@ impl DandiDav {
DavPath::Root => Ok(DavResourceWithChildren::root()),
DavPath::DandisetIndex => {
let col = DavCollection::dandiset_index();
let mut children = Vec::new();
let stream = self.dandi.get_all_dandisets();
tokio::pin!(stream);
while let Some(ds) = stream.try_next().await? {
children.push(DavResource::Collection(ds.into()));
}
let children = self
.dandi
.get_all_dandisets()
.map_ok(|ds| DavResource::Collection(ds.into()))
.try_collect::<Vec<_>>()
.await?;
Ok(DavResourceWithChildren::Collection { col, children })
}
DavPath::Dandiset { dandiset_id } => {
Expand Down Expand Up @@ -282,8 +281,7 @@ impl DandiDav {
let col = DavCollection::dandiset_releases(dandiset_id);
let mut children = Vec::new();
let endpoint = self.dandi.dandiset(dandiset_id.clone());
let stream = endpoint.get_all_versions();
tokio::pin!(stream);
let mut stream = endpoint.get_all_versions();
while let Some(v) = stream.try_next().await? {
if let VersionId::Published(ref pvid) = v.version {
let path = version_path(dandiset_id, &VersionSpec::Published(pvid.clone()));
Expand All @@ -299,12 +297,11 @@ impl DandiDav {
version,
} => {
let (col, endpoint) = self.get_dandiset_version(dandiset_id, version).await?;
let mut children = Vec::new();
let stream = endpoint.get_root_children();
tokio::pin!(stream);
while let Some(res) = stream.try_next().await? {
children.push(DavResource::from(res).under_version_path(dandiset_id, version));
}
let mut children = endpoint
.get_root_children()
.map_ok(|res| DavResource::from(res).under_version_path(dandiset_id, version))
.try_collect::<Vec<_>>()
.await?;
children.push(
self.get_dandiset_yaml(dandiset_id, version)
.await
Expand Down
24 changes: 18 additions & 6 deletions src/httputil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use reqwest::{Method, Request, Response, StatusCode};
use reqwest_middleware::{Middleware, Next};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::de::DeserializeOwned;
use std::future::Future;
use thiserror::Error;
use tracing::Instrument;
use url::Url;
Expand Down Expand Up @@ -53,12 +54,23 @@ impl Client {
self.request(Method::GET, url).await
}

pub(crate) async fn get_json<T: DeserializeOwned>(&self, url: Url) -> Result<T, HttpError> {
self.get(url.clone())
.await?
.json::<T>()
.await
.map_err(move |source| HttpError::Deserialize { url, source })
// Fetch `url` and deserialize the response body as JSON into a `T`.
//
// Written as a non-async fn returning an explicit `impl Future` so that
// `self` can be cloned into the async block: the returned Future is then
// 'static instead of borrowing `&self`, which lets the Paginate stream
// store it.
pub(crate) fn get_json<T: DeserializeOwned>(
&self,
url: Url,
) -> impl Future<Output = Result<T, HttpError>> {
let this = self.clone();
async move {
let response = this.get(url.clone()).await?;
match response.json::<T>().await {
Ok(value) => Ok(value),
// Attach the request URL to deserialization failures.
Err(source) => Err(HttpError::Deserialize { url, source }),
}
}
}
}

Expand Down
Loading

0 comments on commit 00d0714

Please sign in to comment.