Skip to content

Commit

Permalink
refactor(cubesql): Make Postgres authentication extensible (#8709)
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 18, 2024
1 parent 7d77b4a commit 2bb9aa3
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 108 deletions.
7 changes: 4 additions & 3 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use crate::{
},
config::{ConfigObj, ConfigObjImpl},
sql::{
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, AuthContextRef,
AuthenticateResponse, HttpAuthContext, ServerManager, Session, SessionManager,
SqlAuthService,
compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe,
pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse,
HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService,
},
transport::{
CubeStreamReceiver, LoadRequestMeta, SpanId, SqlGenerator, SqlResponse, SqlTemplates,
Expand Down Expand Up @@ -610,6 +610,7 @@ async fn get_test_session_with_config_and_transport(
let server = Arc::new(ServerManager::new(
get_test_auth(),
test_transport.clone(),
Arc::new(PostgresAuthServiceDefaultImpl::new()),
Arc::new(CompilerCacheImpl::new(config_obj.clone(), test_transport)),
None,
config_obj,
Expand Down
12 changes: 11 additions & 1 deletion rust/cubesql/cubesql/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use crate::{
injection::{DIService, Injector},
processing_loop::{ProcessingLoop, ShutdownMode},
},
sql::{PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService},
sql::{
pg_auth_service::{PostgresAuthService, PostgresAuthServiceDefaultImpl},
PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService,
},
transport::{HttpTransport, TransportService},
CubeError,
};
Expand Down Expand Up @@ -302,6 +305,12 @@ impl Config {
})
.await;

self.injector
.register_typed::<dyn PostgresAuthService, _, _, _>(|_| async move {
Arc::new(PostgresAuthServiceDefaultImpl::new())
})
.await;

self.injector
.register_typed::<dyn CompilerCache, _, _, _>(|i| async move {
let config = i.get_service_typed::<dyn ConfigObj>().await;
Expand All @@ -319,6 +328,7 @@ impl Config {
i.get_service_typed().await,
i.get_service_typed().await,
i.get_service_typed().await,
i.get_service_typed().await,
config.nonce().clone(),
config.clone(),
))
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/sql/postgres/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod extended;
pub mod pg_auth_service;
pub(crate) mod pg_type;
pub(crate) mod service;
pub(crate) mod shim;
Expand Down
110 changes: 110 additions & 0 deletions rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use std::{collections::HashMap, fmt::Debug, sync::Arc};

use async_trait::async_trait;

use crate::{
sql::{AuthContextRef, SqlAuthService},
CubeError,
};

pub use pg_srv::{
buffer as pg_srv_buffer,
protocol::{
AuthenticationRequest, AuthenticationRequestExtension, FrontendMessage,
FrontendMessageExtension,
},
MessageTagParser, MessageTagParserDefaultImpl, ProtocolError,
};

#[derive(Debug)]
pub enum AuthenticationStatus {
UnexpectedFrontendMessage,
Failed(String),
// User name + auth context
Success(String, AuthContextRef),
}

#[async_trait]
pub trait PostgresAuthService: Sync + Send + Debug {
fn get_auth_method(&self, parameters: &HashMap<String, String>) -> AuthenticationRequest;

async fn authenticate(
&self,
service: Arc<dyn SqlAuthService>,
request: AuthenticationRequest,
secret: FrontendMessage,
parameters: &HashMap<String, String>,
) -> AuthenticationStatus;

fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser>;
}

#[derive(Debug)]
pub struct PostgresAuthServiceDefaultImpl {
pg_message_tag_parser: Arc<dyn MessageTagParser>,
}

impl PostgresAuthServiceDefaultImpl {
pub fn new() -> Self {
Self {
pg_message_tag_parser: Arc::new(MessageTagParserDefaultImpl::default()),
}
}
}

#[async_trait]
impl PostgresAuthService for PostgresAuthServiceDefaultImpl {
fn get_auth_method(&self, _: &HashMap<String, String>) -> AuthenticationRequest {
AuthenticationRequest::CleartextPassword
}

async fn authenticate(
&self,
service: Arc<dyn SqlAuthService>,
request: AuthenticationRequest,
secret: FrontendMessage,
parameters: &HashMap<String, String>,
) -> AuthenticationStatus {
let FrontendMessage::PasswordMessage(password_message) = secret else {
return AuthenticationStatus::UnexpectedFrontendMessage;
};

if !matches!(request, AuthenticationRequest::CleartextPassword) {
return AuthenticationStatus::UnexpectedFrontendMessage;
}

let user = parameters.get("user").unwrap().clone();
let authenticate_response = service
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
.await;

let auth_fail = || {
AuthenticationStatus::Failed(format!(
"password authentication failed for user \"{}\"",
user
))
};

let Ok(authenticate_response) = authenticate_response else {
return auth_fail();
};

if !authenticate_response.skip_password_check {
let is_password_correct = match authenticate_response.password {
None => false,
Some(password) => password == password_message.password,
};
if !is_password_correct {
return auth_fail();
}
}

AuthenticationStatus::Success(user, authenticate_response.context)
}

fn get_pg_message_tag_parser(&self) -> Arc<dyn MessageTagParser> {
Arc::clone(&self.pg_message_tag_parser)
}
}

crate::di_service!(PostgresAuthServiceDefaultImpl, [PostgresAuthService]);
131 changes: 60 additions & 71 deletions rust/cubesql/cubesql/src/sql/postgres/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
time::SystemTime,
};

use super::extended::PreparedStatement;
use super::{extended::PreparedStatement, pg_auth_service::AuthenticationStatus};
use crate::{
compile::{
convert_statement_to_cube_query,
Expand All @@ -24,8 +24,11 @@ use crate::{
use futures::{pin_mut, FutureExt, StreamExt};
use log::{debug, error, trace};
use pg_srv::{
buffer, protocol,
protocol::{ErrorCode, ErrorResponse, Format, InitialMessage, PortalCompletion},
buffer,
protocol::{
self, AuthenticationRequest, ErrorCode, ErrorResponse, Format, InitialMessage,
PortalCompletion,
},
PgType, PgTypeId, ProtocolError,
};
use sqlparser::ast::{self, CloseCursor, FetchDirection, Query, SetExpr, Statement, Value};
Expand All @@ -46,10 +49,9 @@ pub struct AsyncPostgresShim {
logger: Arc<dyn ContextLogger>,
}

#[derive(PartialEq, Eq)]
pub enum StartupState {
// Initial parameters which client sends in the first message, we use it later in auth method
Success(HashMap<String, String>),
Success(HashMap<String, String>, AuthenticationRequest),
SslRequested,
Denied,
CancelRequest,
Expand Down Expand Up @@ -313,25 +315,23 @@ impl AsyncPostgresShim {
}

pub async fn run(&mut self) -> Result<(), ConnectionError> {
let initial_parameters = match self.process_initial_message().await? {
StartupState::Success(parameters) => parameters,
let (initial_parameters, auth_method) = match self.process_initial_message().await? {
StartupState::Success(parameters, auth_method) => (parameters, auth_method),
StartupState::SslRequested => match self.process_initial_message().await? {
StartupState::Success(parameters) => parameters,
StartupState::Success(parameters, auth_method) => (parameters, auth_method),
_ => return Ok(()),
},
StartupState::Denied | StartupState::CancelRequest => return Ok(()),
};

match buffer::read_message(&mut self.socket).await? {
protocol::FrontendMessage::PasswordMessage(password_message) => {
if !self
.authenticate(password_message, initial_parameters)
.await?
{
return Ok(());
}
}
_ => return Ok(()),
let message_tag_parser = self.session.server.pg_auth.get_pg_message_tag_parser();
let auth_secret =
buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)).await?;
if !self
.authenticate(auth_method, auth_secret, initial_parameters)
.await?
{
return Ok(());
}

self.ready().await?;
Expand All @@ -351,7 +351,7 @@ impl AsyncPostgresShim {
true = async { semifast_shutdownable && { semifast_shutdown_interruptor.cancelled().await; true } } => {
return Self::flush_and_write_admin_shutdown_fatal_message(self).await;
}
message_result = buffer::read_message(&mut self.socket) => message_result?
message_result = buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)) => message_result?
};

let result = match message {
Expand Down Expand Up @@ -716,73 +716,62 @@ impl AsyncPostgresShim {
return Ok(StartupState::Denied);
}

self.write(protocol::Authentication::new(
protocol::AuthenticationRequest::CleartextPassword,
))
.await?;
let auth_method = self.session.server.pg_auth.get_auth_method(&parameters);
self.write(protocol::Authentication::new(auth_method.clone()))
.await?;

Ok(StartupState::Success(parameters))
Ok(StartupState::Success(parameters, auth_method))
}

pub async fn authenticate(
&mut self,
password_message: protocol::PasswordMessage,
auth_request: AuthenticationRequest,
auth_secret: protocol::FrontendMessage,
parameters: HashMap<String, String>,
) -> Result<bool, ConnectionError> {
let user = parameters.get("user").unwrap().clone();
let authenticate_response = self
let auth_service = self.session.server.auth.clone();
let auth_status = self
.session
.server
.auth
.authenticate(Some(user.clone()), Some(password_message.password.clone()))
.pg_auth
.authenticate(auth_service, auth_request, auth_secret, &parameters)
.await;
let result = match auth_status {
AuthenticationStatus::UnexpectedFrontendMessage => Err((
"invalid authorization specification".to_string(),
protocol::ErrorCode::InvalidAuthorizationSpecification,
)),
AuthenticationStatus::Failed(err) => Err((err, protocol::ErrorCode::InvalidPassword)),
AuthenticationStatus::Success(user, auth_context) => Ok((user, auth_context)),
};

let mut auth_context: Option<AuthContextRef> = None;
match result {
Err((message, code)) => {
let error_response = protocol::ErrorResponse::fatal(code, message);
buffer::write_message(
&mut self.partial_write_buf,
&mut self.socket,
error_response,
)
.await?;

let auth_success = match authenticate_response {
Ok(authenticate_response) => {
auth_context = Some(authenticate_response.context);
if !authenticate_response.skip_password_check {
match authenticate_response.password {
None => false,
Some(password) => password == password_message.password,
}
} else {
true
}
Ok(false)
}
_ => false,
};

if !auth_success {
let error_response = protocol::ErrorResponse::fatal(
protocol::ErrorCode::InvalidPassword,
format!("password authentication failed for user \"{}\"", &user),
);
buffer::write_message(
&mut self.partial_write_buf,
&mut self.socket,
error_response,
)
.await?;
Ok((user, auth_context)) => {
let database = parameters
.get("database")
.map(|v| v.clone())
.unwrap_or("db".to_string());
self.session.state.set_database(Some(database));
self.session.state.set_user(Some(user));
self.session.state.set_auth_context(Some(auth_context));

self.write(protocol::Authentication::new(AuthenticationRequest::Ok))
.await?;

return Ok(false);
Ok(true)
}
}

let database = parameters
.get("database")
.map(|v| v.clone())
.unwrap_or("db".to_string());
self.session.state.set_database(Some(database));
self.session.state.set_user(Some(user));
self.session.state.set_auth_context(auth_context);

self.write(protocol::Authentication::new(
protocol::AuthenticationRequest::Ok,
))
.await?;

Ok(true)
}

pub async fn ready(&mut self) -> Result<(), ConnectionError> {
Expand Down
4 changes: 4 additions & 0 deletions rust/cubesql/cubesql/src/sql/server_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
sql::{
compiler_cache::CompilerCache,
database_variables::{mysql_default_global_variables, postgres_default_global_variables},
pg_auth_service::PostgresAuthService,
SqlAuthService,
},
transport::TransportService,
Expand Down Expand Up @@ -37,6 +38,7 @@ pub struct ServerManager {
// References to shared things
pub auth: Arc<dyn SqlAuthService>,
pub transport: Arc<dyn TransportService>,
pub pg_auth: Arc<dyn PostgresAuthService>,
// Non references
pub configuration: ServerConfiguration,
pub nonce: Option<Vec<u8>>,
Expand All @@ -52,13 +54,15 @@ impl ServerManager {
pub fn new(
auth: Arc<dyn SqlAuthService>,
transport: Arc<dyn TransportService>,
pg_auth: Arc<dyn PostgresAuthService>,
compiler_cache: Arc<dyn CompilerCache>,
nonce: Option<Vec<u8>>,
config_obj: Arc<dyn ConfigObj>,
) -> Self {
Self {
auth,
transport,
pg_auth,
compiler_cache,
nonce,
config_obj,
Expand Down
Loading

0 comments on commit 2bb9aa3

Please sign in to comment.