Skip to content

Commit

Permalink
Parallelized loading in MyDataService::new (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethe committed Apr 19, 2024
1 parent cb4cd1b commit 053169e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
26 changes: 26 additions & 0 deletions data_server/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions data_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2021"
bytes = "1.5.0"
clap = { version = "4.4.11", features = ["derive"] }
env_logger = "0.10.1"
futures-lite = "2.3.0"
log = "0.4.20"
prost = "0.12.3"
rand = "0.8.5"
Expand Down
40 changes: 29 additions & 11 deletions data_server/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use clap::Parser;
use futures_lite::future::block_on;
use log::info;
use prost::Message;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::fs::File;
use std::io::{self, BufReader, Read, Result as IoResult};
use std::sync::Arc;
use std::sync::Mutex;
use std::vec;
use tonic::{transport::Server, Request, Response, Status};

Expand Down Expand Up @@ -57,20 +60,35 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>

impl MyDataService {
pub fn new(files: Vec<String>, causual_sampling: bool) -> IoResult<Self> {
let mut groups = Vec::new();
let mut weights = Vec::new();

for filename in files.iter() {
let file = File::open(filename)?;
let reader = BufReader::new(file);
let groups = Vec::new();
let weights = Vec::new();

let guarded = Arc::new(Mutex::new((groups, weights)));

let mut joins = Vec::with_capacity(files.len());
for filename in files {
let g = guarded.clone();
joins.push(tokio::task::spawn_blocking(move || {
let file = File::open(filename)?;
let reader = BufReader::new(file);

// read_pb_stream decodes the whole file into a Vec<TextData>; record each entry and its weight
for text_data in read_pb_stream(reader)? {
let (groups, weights) = &mut *g.lock().unwrap();
groups.push(text_data.clone());
weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
}

Ok::<_, io::Error>(())
}));
}

// Assuming read_pb_stream is implemented and it returns an iterator over TextData
for text_data in read_pb_stream(reader)? {
groups.push(text_data.clone());
weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData
}
for join in joins {
block_on(join)??;
}

let (groups, weights) = Arc::into_inner(guarded).unwrap().into_inner().unwrap();

info!("Loaded {} groups", groups.len());

Ok(MyDataService {
Expand Down

0 comments on commit 053169e

Please sign in to comment.