Skip to content

Commit

Permalink
Add an option to specify data server address.
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue committed Jan 20, 2024
1 parent 39f6902 commit 61a8ab1
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions data_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ impl MyDataService {

info!("Loaded {} groups", groups.len());

Ok(MyDataService { groups, weights, causual_sampling })
Ok(MyDataService {
groups,
weights,
causual_sampling,
})
}
}

Expand Down Expand Up @@ -105,30 +109,28 @@ impl DataService for MyDataService {
let max = group.sentences.len() - num_samples;
if max <= 0 {
return Ok(Response::new(SampledData {
name: group.name.clone(),
name: group.name.clone(),
source: group.source.clone(),
samples: group.sentences.clone(),
}));
}

let start = rng.gen_range(0..max);
Ok(Response::new(SampledData {
name: group.name.clone(),
name: group.name.clone(),
source: group.source.clone(),
samples: group.sentences[start..start + num_samples].to_vec(),
}))
} else {
let sentences_ref = group
.sentences
.choose_multiple(&mut rng, num_samples);
let sentences_ref = group.sentences.choose_multiple(&mut rng, num_samples);

let sentences: Vec<Sentence> = sentences_ref
.into_iter()
.cloned() // Clone each &Sentence to get Sentence
.collect();

Ok(Response::new(SampledData {
name: group.name.clone(),
name: group.name.clone(),
source: group.source.clone(),
samples: sentences,
}))
Expand All @@ -146,7 +148,11 @@ struct Args {

/// Causual sampling
#[clap(short, long, default_value = "false")]
causal: bool
causal: bool,

/// Address to bind to
#[clap(short, long, default_value = "127.0.0.1:50051")]
address: String,
}

#[tokio::main]
Expand All @@ -157,7 +163,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
info!("Arguments: {:?}", args);

let addr = "127.0.0.1:50051".parse()?;
let addr = args.address.parse()?;
let data_service = MyDataService::new(args.files, args.causal)?;

info!("Starting server at {}", addr);
Expand Down

0 comments on commit 61a8ab1

Please sign in to comment.