Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add sqltap #496

Merged
merged 16 commits into from
Jun 22, 2023
4 changes: 3 additions & 1 deletion store/neurostore/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from connexion.resolver import MethodViewResolver
from flask_cors import CORS
import prance
import sqltap.wsgi

from .or_json import ORJSONDecoder, ORJSONEncoder
from .database import init_db
Expand All @@ -24,7 +25,6 @@

db = init_db(app)


app.secret_key = app.config["JWT_SECRET_KEY"]

options = {"swagger_ui": True}
Expand Down Expand Up @@ -71,6 +71,8 @@ def get_bundled_specs(main_file):
},
)

if app.debug:
app.wsgi_app = sqltap.wsgi.SQLTapMiddleware(app.wsgi_app, path="/api/__sqltap__")

app.json_encoder = ORJSONEncoder
app.json_decoder = ORJSONDecoder
15 changes: 5 additions & 10 deletions store/neurostore/ingest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@

def ingest_neurovault(verbose=False, limit=20, overwrite=False):
# Store existing studies for quick lookup
all_studies = {
s.doi: s for s in Study.query.filter_by(source="neurovault").all()
}
all_studies = {s.doi: s for s in Study.query.filter_by(source="neurovault").all()}

def add_collection(data):
if data["DOI"] in all_studies and not overwrite:
Expand Down Expand Up @@ -67,7 +65,6 @@ def add_collection(data):
source="neurovault",
level="group",
base_study=base_study,

)

space = data.get("coordinate_space", None)
Expand Down Expand Up @@ -227,7 +224,9 @@ def ingest_neurosynth(max_rows=None):
# do not overwrite the versions column
# we want to append to this column
columns = [
c for c in source_base_study.__table__.columns if c != "versions"
c
for c in source_base_study.__table__.columns
if c != "versions"
]
for ab in base_studies[1:]:
for col in columns:
Expand Down Expand Up @@ -379,11 +378,7 @@ def ingest_neuroquery(max_rows=None):
base_study = BaseStudy.query.filter_by(pmid=id_).one_or_none()

if base_study is None:
base_study = BaseStudy(
name=metadata_row["title"],
level="group",
pmid=id_
)
base_study = BaseStudy(name=metadata_row["title"], level="group", pmid=id_)
study_coord_data = coord_data.loc[[id_]]
s = Study(
name=metadata_row["title"] or base_study.name,
Expand Down
2 changes: 2 additions & 0 deletions store/neurostore/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .data import (
Studyset,
StudysetStudy,
Annotation,
BaseStudy,
Study,
Expand All @@ -16,6 +17,7 @@

__all__ = [
"Studyset",
"StudysetStudy",
"Annotation",
"BaseStudy",
"Study",
Expand Down
56 changes: 39 additions & 17 deletions store/neurostore/models/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import event, ForeignKeyConstraint
from sqlalchemy.ext.associationproxy import association_proxy
Expand All @@ -7,6 +8,7 @@
from sqlalchemy.sql import func
import shortuuid

from .migration_types import TSVector
from ..database import db


Expand Down Expand Up @@ -105,21 +107,31 @@ class BaseStudy(BaseMixin, db.Model):

name = db.Column(db.String)
description = db.Column(db.String)
publication = db.Column(db.String)
doi = db.Column(db.String, nullable=True)
pmid = db.Column(db.String, nullable=True)
authors = db.Column(db.String)
year = db.Column(db.Integer)
publication = db.Column(db.String, index=True)
doi = db.Column(db.String, nullable=True, index=True)
pmid = db.Column(db.String, nullable=True, index=True)
authors = db.Column(db.String, index=True)
year = db.Column(db.Integer, index=True)
public = db.Column(db.Boolean, default=True)
level = db.Column(db.String)
metadata_ = db.Column(JSONB)
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"), index=True)
__ts_vector__ = db.Column(
TSVector(),
db.Computed(
"to_tsvector('english', coalesce(name, '') || ' ' || coalesce(description, ''))",
persisted=True,
),
)

user = relationship("User", backref=backref("base_studies"))
# retrieve versions of same study
versions = relationship("Study", backref=backref("base_study"))

__table_args__ = (
db.CheckConstraint(level.in_(["group", "meta"])),
db.UniqueConstraint('doi', 'pmid', name='doi_pmid'),
db.UniqueConstraint("doi", "pmid", name="doi_pmid"),
sa.Index("ix_base_study___ts_vector__", __ts_vector__, postgresql_using="gin"),
)


Expand All @@ -128,27 +140,37 @@ class Study(BaseMixin, db.Model):

name = db.Column(db.String)
description = db.Column(db.String)
publication = db.Column(db.String)
doi = db.Column(db.String)
pmid = db.Column(db.String)
authors = db.Column(db.String)
year = db.Column(db.Integer)
publication = db.Column(db.String, index=True)
doi = db.Column(db.String, index=True)
pmid = db.Column(db.String, index=True)
authors = db.Column(db.String, index=True)
year = db.Column(db.Integer, index=True)
public = db.Column(db.Boolean, default=True)
level = db.Column(db.String)
metadata_ = db.Column(JSONB)
source = db.Column(db.String)
source_id = db.Column(db.String)
source = db.Column(db.String, index=True)
source_id = db.Column(db.String, index=True)
source_updated_at = db.Column(db.DateTime(timezone=True))
base_study_id = db.Column(db.Text, db.ForeignKey('base_studies.id'))
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"))
base_study_id = db.Column(db.Text, db.ForeignKey("base_studies.id"))
user_id = db.Column(db.Text, db.ForeignKey("users.external_id"), index=True)
__ts_vector__ = db.Column(
TSVector(),
db.Computed(
"to_tsvector('english', coalesce(name, '') || ' ' || coalesce(description, ''))",
persisted=True,
),
)
user = relationship("User", backref=backref("studies"))
analyses = relationship(
"Analysis",
backref=backref("study"),
cascade="all, delete, delete-orphan",
)

__table_args__ = (db.CheckConstraint(level.in_(["group", "meta"])),)
__table_args__ = (
db.CheckConstraint(level.in_(["group", "meta"])),
sa.Index("ix_study___ts_vector__", __ts_vector__, postgresql_using="gin"),
)


class StudysetStudy(db.Model):
Expand Down
7 changes: 7 additions & 0 deletions store/neurostore/models/migration_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import TSVECTOR


class TSVector(sa.types.TypeDecorator):
    """PostgreSQL ``tsvector`` column type for full-text search.

    Thin wrapper around the dialect-specific ``TSVECTOR`` type so it can be
    referenced from model definitions and Alembic migrations (e.g. the
    ``__ts_vector__`` computed columns with GIN indexes).
    """

    impl = TSVECTOR
    # SQLAlchemy 1.4+: a TypeDecorator must opt in to the SQL statement
    # cache. Without this, every statement using the type emits a SAWarning
    # and is re-compiled on each execution.
    cache_ok = True
2 changes: 1 addition & 1 deletion store/neurostore/openapi
1 change: 1 addition & 0 deletions store/neurostore/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ python-jose~=3.3
shortuuid~=1.0
sqlalchemy~=1.4
sqlalchemy-utils~=0.36
sqltap
webargs~=7.0
wheel~=0.36
wrapt~=1.12
Expand Down
26 changes: 13 additions & 13 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,13 @@ class ListView(BaseView):

def __init__(self):
# Initialize expected arguments based on class attributes
self._fulltext_fields = self._multi_search or self._search_fields
self._fulltext_fields = self._multi_search
self._user_args = {
**LIST_USER_ARGS,
**self._view_fields,
**{f: fields.Str() for f in self._fulltext_fields},
}
if self._fulltext_fields:
self._user_args.update({f: fields.Str() for f in self._fulltext_fields})

def view_search(self, q, args):
    """Hook for subclasses to apply endpoint-specific filters to *q*.

    Default implementation is a no-op: the query is returned unchanged.
    Subclasses (e.g. views filtering by data_type or level) override this
    to narrow the result set based on parsed request *args*.
    """
    return q
Expand All @@ -241,9 +242,8 @@ def serialize_records(self, records, args):
).dump(records)
return content

def create_metadata(self, q):
count = len(q.all())
return {"total_count": count}
def create_metadata(self, q, total):
    """Build the ``metadata`` section of a list response.

    ``q`` is the (unpaginated) query, unused here but kept so subclasses
    can derive extra metadata from it; ``total`` is the full matching
    record count as reported by the pagination object, avoiding a second
    ``len(q.all())`` round trip.
    """
    return {"total_count": total}

def search(self):
# Parse arguments using webargs
Expand All @@ -266,10 +266,7 @@ def search(self):

# For multi-column search, default to using search fields
if s is not None and self._fulltext_fields:
search_expr = [
getattr(m, field).ilike(f"%{s}%") for field in self._fulltext_fields
]
q = q.filter(sae.or_(*search_expr))
q = q.filter(m.__ts_vector__.match(s))

# Alternatively (or in addition), search on individual fields.
for field in self._search_fields:
Expand Down Expand Up @@ -299,11 +296,14 @@ def search(self):
# join the relevant tables for output
q = self.join_tables(q)

records = q.paginate(
page=args["page"], per_page=args["page_size"], error_out=False,
).items
pagination_query = q.paginate(
page=args["page"],
per_page=args["page_size"],
error_out=False,
)
records = pagination_query.items
content = self.serialize_records(records, args)
metadata = self.create_metadata(q)
metadata = self.create_metadata(q, pagination_query.total)
response = {
"metadata": metadata,
"results": content,
Expand Down
44 changes: 30 additions & 14 deletions store/neurostore/resources/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from marshmallow import EXCLUDE
from webargs import fields
import sqlalchemy.sql.expression as sae
from sqlalchemy.orm import joinedload

from .utils import view_maker
from .base import BaseView, ObjectView, ListView
Expand Down Expand Up @@ -62,6 +63,7 @@ class StudysetsView(ObjectView, ListView):
_linked = {
"annotations": "AnnotationsView",
}
_multi_search = ("name", "description")
_search_fields = ("name", "description", "publication", "doi", "pmid")

def view_search(self, q, args):
Expand All @@ -88,6 +90,7 @@ class AnnotationsView(ObjectView, ListView):
"studyset": "StudysetsView",
}

_multi_search = ("name", "description")
_search_fields = ("name", "description")

def view_search(self, q, args):
Expand Down Expand Up @@ -141,14 +144,14 @@ def load_from_neurostore(cls, source_id):

@view_maker
class BaseStudiesView(ObjectView, ListView):
_nested = {
"versions": "StudiesView"
}
_nested = {"versions": "StudiesView"}

_view_fields = {
"level": fields.String(default="group", missing="group"),
}

_multi_search = ("name", "description")

_search_fields = (
"name",
"description",
Expand All @@ -164,14 +167,22 @@ def view_search(self, q, args):
# search studies for data_type
if args.get("data_type"):
if args["data_type"] == "coordinate":
q = q.filter(self._model.versions.any(Study.analyses.any(Analysis.points.any())))
q = q.filter(
self._model.versions.any(Study.analyses.any(Analysis.points.any()))
)
elif args["data_type"] == "image":
q = q.filter(self._model.versions.any(Study.analyses.any(Analysis.images.any())))
q = q.filter(
self._model.versions.any(Study.analyses.any(Analysis.images.any()))
)
elif args["data_type"] == "both":
q = q.filter(
sae.or_(
self._model.versions.any(Study.analyses.any(Analysis.points.any())),
self._model.versions.any(Study.analyses.any(Analysis.images.any())),
self._model.versions.any(
Study.analyses.any(Analysis.points.any())
),
self._model.versions.any(
Study.analyses.any(Analysis.images.any())
),
)
)
# filter by level of analysis (group or meta)
Expand All @@ -180,6 +191,11 @@ def view_search(self, q, args):

return q

def join_tables(self, q):
    """Eagerly load related tables to speed up the list query.

    Attaches a ``joinedload`` on the ``versions`` relationship so that
    serializing each base study does not issue one extra SELECT per
    record (N+1 queries).
    """
    q = q.options(joinedload("versions"))
    return q


@view_maker
class StudiesView(ObjectView, ListView):
Expand Down Expand Up @@ -235,16 +251,16 @@ def view_search(self, q, args):
"doi" if isinstance(unique_col, bool) and unique_col else unique_col
)
if unique_col:
q_null = q.filter(getattr(self._model, unique_col).is_(None))
q_distinct = q.distinct(getattr(self._model, unique_col))
q = q_distinct.union(q_null)
q = q.order_by(getattr(self._model, unique_col))
subquery = q.distinct(getattr(self._model, unique_col)).subquery()
q = q.join(
subquery,
getattr(self._model, unique_col) == getattr(subquery.c, unique_col),
)
return q

def join_tables(self, q):
"join relevant tables to speed up query"
q = q.outerjoin(Analysis)

q = q.options(joinedload("analyses"))
return q

def serialize_records(self, records, args):
Expand Down Expand Up @@ -281,7 +297,7 @@ def load_from_neurostore(cls, source_id):
data["source"] = "neurostore"
data["source_id"] = source_id
data["source_updated_at"] = study.updated_at or study.created_at
data['base_study'] = {"id": study.base_study_id}
data["base_study"] = {"id": study.base_study_id}
return data

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions store/neurostore/schemas/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ class StudySchema(BaseDataSchema):
source_id = fields.String(
dump_only=True, metadata={"db_only": True}, allow_none=True
)
studysets = fields.Nested(
"StudySetStudyInfoSchema", dump_only=True, metadata={"db_only": True}, many=True
)
# studysets = fields.Nested(
# "StudySetStudyInfoSchema", dump_only=True, metadata={"db_only": True}, many=True
# )
source_updated_at = fields.DateTime(
dump_only=True, metadata={"db_only": True}, allow_none=True
)
Expand Down
4 changes: 2 additions & 2 deletions store/neurostore/tests/api/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def test_create(auth_client, user_data, endpoint, model, schema):
for row in rows:
payload = schema(copy=True).dump(row)
if model is BaseStudy:
payload['doi'] = payload['doi'] + "new"
payload['pmid'] = payload['pmid'] + "new"
payload["doi"] = payload["doi"] + "new"
payload["pmid"] = payload["pmid"] + "new"

resp = auth_client.post(f"/api/{endpoint}/", data=payload)
if resp.status_code == 422:
Expand Down
Loading
Loading