From 053169e55f26ba25b52eb0c0685432ceb5f557e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=AD?= Date: Fri, 19 Apr 2024 12:06:32 +0800 Subject: [PATCH] Parallelized loading in MyDataService::new (#128) --- data_server/Cargo.lock | 26 ++++++++++++++++++++++++++ data_server/Cargo.toml | 1 + data_server/src/main.rs | 40 +++++++++++++++++++++++++++++----------- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/data_server/Cargo.lock b/data_server/Cargo.lock index 9b37bf1d..a7992607 100644 --- a/data_server/Cargo.lock +++ b/data_server/Cargo.lock @@ -271,6 +271,7 @@ dependencies = [ "bytes", "clap", "env_logger", + "futures-lite", "log", "prost", "rand", @@ -347,6 +348,25 @@ version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-lite" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-sink" version = "0.3.29" @@ -659,6 +679,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.1" diff --git a/data_server/Cargo.toml b/data_server/Cargo.toml index 94f2f950..4d7f730c 100644 --- a/data_server/Cargo.toml +++ b/data_server/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" bytes = "1.5.0" clap = { version = "4.4.11", features = ["derive"] } env_logger = "0.10.1" +futures-lite = "2.3.0" log = "0.4.20" prost = "0.12.3" rand = "0.8.5" diff --git a/data_server/src/main.rs b/data_server/src/main.rs index 014696e6..65c8b380 100644 --- a/data_server/src/main.rs +++ b/data_server/src/main.rs @@ -1,10 +1,13 @@ use clap::Parser; +use futures_lite::future::block_on; use log::info; use prost::Message; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; use std::fs::File; use std::io::{self, BufReader, Read, Result as IoResult}; +use std::sync::Arc; +use std::sync::Mutex; use std::vec; use tonic::{transport::Server, Request, Response, Status}; @@ -57,20 +60,35 @@ fn read_pb_stream(mut reader: BufReader) -> io::Result impl MyDataService { pub fn new(files: Vec, causual_sampling: bool) -> IoResult { - let mut groups = Vec::new(); - let mut weights = Vec::new(); - - for filename in files.iter() { - let file = File::open(filename)?; - let reader = BufReader::new(file); + let groups = Vec::new(); + let weights = Vec::new(); + + let guarded = Arc::new(Mutex::new((groups, weights))); + + let mut joins = Vec::with_capacity(files.len()); + for filename in files { + let g = guarded.clone(); + joins.push(tokio::task::spawn_blocking(move || { + let file = File::open(filename)?; + let reader = BufReader::new(file); + + // Assuming read_pb_stream is implemented and it returns an iterator over TextData + for text_data in read_pb_stream(reader)? { + let (groups, weights) = &mut *g.lock().unwrap(); + groups.push(text_data.clone()); + weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData + } + + Ok::<_, io::Error>(()) + })); + } - // Assuming read_pb_stream is implemented and it returns an iterator over TextData - for text_data in read_pb_stream(reader)? { - groups.push(text_data.clone()); - weights.push(text_data.sentences.len() as f32); // Assuming sentences is a repeated field in TextData - } + for join in joins { + block_on(join)??; } + let (groups, weights) = Arc::into_inner(guarded).unwrap().into_inner().unwrap(); + info!("Loaded {} groups", groups.len()); Ok(MyDataService {