Skip to content

Commit

Permalink
foldseek DBSCAN with RBH filter and NN rescue
Browse files Browse the repository at this point in the history
  • Loading branch information
Woosub-Kim committed Feb 9, 2024
1 parent 3bf3cdf commit 2503196
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
5 changes: 1 addition & 4 deletions src/strucclustutils/createcomplexreport.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@ const double FILTERED_OUT = 0.0;
const unsigned int UNCLUSTERED = 0;
const unsigned int CLUSTERED = 1;
const unsigned int MIN_PTS = 2;
const unsigned int SINGLE_CHAINED_COMPLEX = 1;
const float BIT_SCORE_MARGIN = 0.9;
//const float CLUSTERING_STEPS = 100.0;
//const float DEF_DIST = -1.0;
const float DEF_BIT_SCORE = -1.0;
const int UNINITIALIZED = 0;
const float LEARNING_RATE = 0.1;
const float DEFAULT_EPS = 0.1;
const unsigned int FINISH_CLUSTERING = 2;
const unsigned int MULTIPLE_CHAIN = 2;
typedef std::pair<std::string, std::string> compNameChainName_t;
typedef std::map<unsigned int, unsigned int> chainKeyToComplexId_t;
typedef std::map<unsigned int, std::vector<unsigned int>> complexIdToChainKeys_t;
Expand Down
81 changes: 52 additions & 29 deletions src/strucclustutils/scorecomplex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ struct Assignment {
}
};

struct NeighborsWithDist {
NeighborsWithDist(unsigned int neighbor, float dist) : neighbor(neighbor), dist(dist) {}
unsigned int neighbor;
float dist;
};

bool compareChainToChainAlnByDbComplexId(const ChainToChainAln &first, const ChainToChainAln &second) {
if (first.dbChain.complexId < second.dbChain.complexId)
return true;
Expand Down Expand Up @@ -291,6 +297,14 @@ bool compareAssignment(const Assignment &first, const Assignment &second) {
return false;
}

bool compareNeighborWithDist(const NeighborsWithDist &first, const NeighborsWithDist &second) {
if (first.dist < second.dist)
return true;
if (first.dist > second.dist)
return false;
return false;
}

class DBSCANCluster {
public:
DBSCANCluster(SearchResult &searchResult, double minCov) : searchResult(searchResult) {
Expand All @@ -307,18 +321,18 @@ class DBSCANCluster {
unsigned int getAlnClusters() {
// rbh filter
filterAlnsByRBH();
fillDistMap();
// To skip DBSCAN clustering when alignments are few enough.
if (searchResult.alnVec.size() <= idealClusterSize)
return checkClusteringNecessity();
fillDistMap();

return runDBSCAN();
}

private:
SearchResult &searchResult;
float eps;
float maxDist;
// float minDist;
float learningRate;
unsigned int cLabel;
unsigned int prevMaxClusterSize;
Expand All @@ -327,13 +341,14 @@ class DBSCANCluster {
unsigned int minClusterSize;
std::vector<unsigned int> neighbors;
std::vector<unsigned int> neighborsOfCurrNeighbor;
std::vector<unsigned int> qFoundChainKeys;
std::vector<unsigned int> dbFoundChainKeys;
std::set<unsigned int> qFoundChainKeys;
std::set<unsigned int> dbFoundChainKeys;
distMap_t distMap;
std::vector<cluster_t> currClusters;
std::set<cluster_t> finalClusters;
std::map<unsigned int, float> qBestBitScore;
std::map<unsigned int, float> dbBestBitScore;
std::vector<NeighborsWithDist> neighborsWithDist;

unsigned int runDBSCAN() {
initializeAlnLabels();
Expand Down Expand Up @@ -368,13 +383,8 @@ class DBSCANCluster {
}
}

// too big cluster
if (neighbors.size() > idealClusterSize)
continue;

// redundant chains
if (checkChainRedundancy())
continue;
if (neighbors.size() > idealClusterSize || checkChainRedundancy())
getNearestNeighbors(centerAlnIdx);

// too small cluster
if (neighbors.size() < maxClusterSize)
Expand Down Expand Up @@ -407,22 +417,16 @@ class DBSCANCluster {

void fillDistMap() {
float dist;
// minDist = DEF_DIST;
distMap.clear();
for (size_t i=0; i < searchResult.alnVec.size(); i++) {
ChainToChainAln &prevAln = searchResult.alnVec[i];
for (size_t j = i+1; j < searchResult.alnVec.size(); j++) {
ChainToChainAln &currAln = searchResult.alnVec[j];
dist = prevAln.getDistance(currAln);
maxDist = std::max(maxDist, dist);
// minDist = minDist<UNINITIALIZED ? dist : std::min(minDist, dist);
distMap.insert({{i,j}, dist});
}
}
// eps = minDist;
// learningRate = (maxDist - minDist) / CLUSTERING_STEPS;
// eps = 0.1;
// learningRate = 0.1;
}

void getNeighbors(unsigned int centerIdx, std::vector<unsigned int> &neighborVec) {
Expand Down Expand Up @@ -451,19 +455,12 @@ class DBSCANCluster {
dbFoundChainKeys.clear();

for (auto neighborIdx : neighbors) {
unsigned int qChainKey = searchResult.alnVec[neighborIdx].qChain.chainKey;
unsigned int dbChainKey = searchResult.alnVec[neighborIdx].dbChain.chainKey;

if (std::find(qFoundChainKeys.begin(), qFoundChainKeys.end(), qChainKey) != qFoundChainKeys.end())
if (!qFoundChainKeys.insert(searchResult.alnVec[neighborIdx].qChain.chainKey).second)
return true;

if (std::find(dbFoundChainKeys.begin(), dbFoundChainKeys.end(), dbChainKey) != dbFoundChainKeys.end())
if (!dbFoundChainKeys.insert(searchResult.alnVec[neighborIdx].dbChain.chainKey).second)
return true;

qFoundChainKeys.emplace_back(qChainKey);
dbFoundChainKeys.emplace_back(dbChainKey);
}

return false;
}

Expand All @@ -475,9 +472,9 @@ class DBSCANCluster {
}
if (checkChainRedundancy()) {
neighbors.clear();
if (searchResult.alnVec.size() < FINISH_CLUSTERING)
if (searchResult.alnVec.size() < MULTIPLE_CHAIN)
finishDBSCAN();
fillDistMap();

return runDBSCAN();
}
prevMaxClusterSize = neighbors.size();
Expand Down Expand Up @@ -531,6 +528,32 @@ class DBSCANCluster {
}
searchResult.alnVec.erase(searchResult.alnVec.begin() + alnIdx);
}
// return;
}

void getNearestNeighbors(unsigned int centerIdx) {
qFoundChainKeys.clear();
dbFoundChainKeys.clear();
neighborsWithDist.clear();

for (auto neighborIdx: neighbors) {
if (neighborIdx == centerIdx) {
neighborsWithDist.emplace_back(neighborIdx, 0.0);
continue;
}
neighborsWithDist.emplace_back(neighborIdx, neighborIdx < centerIdx ? distMap[{neighborIdx, centerIdx}] : distMap[{centerIdx, neighborIdx}]);
}
SORT_SERIAL(neighborsWithDist.begin(), neighborsWithDist.end(), compareNeighborWithDist);
neighbors.clear();
for (auto neighborWithDist : neighborsWithDist) {
if (!qFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].qChain.chainKey).second)
break;

if (!dbFoundChainKeys.insert(searchResult.alnVec[neighborWithDist.neighbor].dbChain.chainKey).second)
break;

neighbors.emplace_back(neighborWithDist.neighbor);
}
// return;
}
};
Expand Down Expand Up @@ -793,7 +816,7 @@ int scorecomplex(int argc, const char **argv, const Command &command) {
for (size_t qCompIdx = 0; qCompIdx < qComplexIndices.size(); qCompIdx++) {
unsigned int qComplexId = qComplexIndices[qCompIdx];
std::vector<unsigned int> &qChainKeys = qComplexIdToChainKeysMap.at(qComplexId);
if (qChainKeys.size() <= SINGLE_CHAINED_COMPLEX)
if (qChainKeys.size() < MULTIPLE_CHAIN)
continue;
complexScorer.getSearchResults(qComplexId, qChainKeys, dbChainKeyToComplexIdMap, dbComplexIdToChainKeysMap, searchResults);
// for each db complex
Expand Down

0 comments on commit 2503196

Please sign in to comment.