From 2bb9aa3ce13e8b34be2d50c57057221fb7e03d65 Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:17:28 +0400 Subject: [PATCH] refactor(cubesql): Make Postgres authentication extensible (#8709) --- rust/cubesql/cubesql/src/compile/test/mod.rs | 7 +- rust/cubesql/cubesql/src/config/mod.rs | 12 +- rust/cubesql/cubesql/src/sql/postgres/mod.rs | 1 + .../src/sql/postgres/pg_auth_service.rs | 110 +++++++++++++++ rust/cubesql/cubesql/src/sql/postgres/shim.rs | 131 ++++++++---------- .../cubesql/cubesql/src/sql/server_manager.rs | 4 + rust/cubesql/pg-srv/src/buffer.rs | 81 ++++++++--- rust/cubesql/pg-srv/src/protocol.rs | 39 ++++-- 8 files changed, 277 insertions(+), 108 deletions(-) create mode 100644 rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 23dee919bf52a..544ff648bdf4b 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -16,9 +16,9 @@ use crate::{ }, config::{ConfigObj, ConfigObjImpl}, sql::{ - compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, AuthContextRef, - AuthenticateResponse, HttpAuthContext, ServerManager, Session, SessionManager, - SqlAuthService, + compiler_cache::CompilerCacheImpl, dataframe::batches_to_dataframe, + pg_auth_service::PostgresAuthServiceDefaultImpl, AuthContextRef, AuthenticateResponse, + HttpAuthContext, ServerManager, Session, SessionManager, SqlAuthService, }, transport::{ CubeStreamReceiver, LoadRequestMeta, SpanId, SqlGenerator, SqlResponse, SqlTemplates, @@ -610,6 +610,7 @@ async fn get_test_session_with_config_and_transport( let server = Arc::new(ServerManager::new( get_test_auth(), test_transport.clone(), + Arc::new(PostgresAuthServiceDefaultImpl::new()), Arc::new(CompilerCacheImpl::new(config_obj.clone(), test_transport)), None, config_obj, diff --git a/rust/cubesql/cubesql/src/config/mod.rs b/rust/cubesql/cubesql/src/config/mod.rs index 19c4110f894b1..0beecaae8bbb7 100644 --- a/rust/cubesql/cubesql/src/config/mod.rs +++ b/rust/cubesql/cubesql/src/config/mod.rs @@ -6,7 +6,10 @@ use crate::{ injection::{DIService, Injector}, processing_loop::{ProcessingLoop, ShutdownMode}, }, - sql::{PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService}, + sql::{ + pg_auth_service::{PostgresAuthService, PostgresAuthServiceDefaultImpl}, + PostgresServer, ServerManager, SessionManager, SqlAuthDefaultImpl, SqlAuthService, + }, transport::{HttpTransport, TransportService}, CubeError, }; @@ -302,6 +305,12 @@ impl Config { }) .await; + self.injector + .register_typed::(|_| async move { + Arc::new(PostgresAuthServiceDefaultImpl::new()) + }) + .await; + self.injector .register_typed::(|i| async move { let config = i.get_service_typed::().await; @@ -319,6 +328,7 @@ impl Config { i.get_service_typed().await, i.get_service_typed().await, i.get_service_typed().await, + i.get_service_typed().await, config.nonce().clone(), config.clone(), )) diff --git a/rust/cubesql/cubesql/src/sql/postgres/mod.rs b/rust/cubesql/cubesql/src/sql/postgres/mod.rs index bf8176b7fe6b6..38f111ebb2252 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/mod.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod extended; +pub mod pg_auth_service; pub(crate) mod pg_type; pub(crate) mod service; pub(crate) mod shim; diff --git a/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs new file mode 100644 index 0000000000000..50698f18bb938 --- /dev/null +++ b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs @@ -0,0 +1,110 @@ +use std::{collections::HashMap, fmt::Debug, sync::Arc}; + +use async_trait::async_trait; + +use crate::{ + sql::{AuthContextRef, SqlAuthService}, + CubeError, +}; + +pub use pg_srv::{ + buffer as pg_srv_buffer, + protocol::{ + AuthenticationRequest, AuthenticationRequestExtension, FrontendMessage, + FrontendMessageExtension, + }, + MessageTagParser, MessageTagParserDefaultImpl, ProtocolError, +}; + +#[derive(Debug)] +pub enum AuthenticationStatus { + UnexpectedFrontendMessage, + Failed(String), + // User name + auth context + Success(String, AuthContextRef), +} + +#[async_trait] +pub trait PostgresAuthService: Sync + Send + Debug { + fn get_auth_method(&self, parameters: &HashMap) -> AuthenticationRequest; + + async fn authenticate( + &self, + service: Arc, + request: AuthenticationRequest, + secret: FrontendMessage, + parameters: &HashMap, + ) -> AuthenticationStatus; + + fn get_pg_message_tag_parser(&self) -> Arc; +} + +#[derive(Debug)] +pub struct PostgresAuthServiceDefaultImpl { + pg_message_tag_parser: Arc, +} + +impl PostgresAuthServiceDefaultImpl { + pub fn new() -> Self { + Self { + pg_message_tag_parser: Arc::new(MessageTagParserDefaultImpl::default()), + } + } +} + +#[async_trait] +impl PostgresAuthService for PostgresAuthServiceDefaultImpl { + fn get_auth_method(&self, _: &HashMap) -> AuthenticationRequest { + AuthenticationRequest::CleartextPassword + } + + async fn authenticate( + &self, + service: Arc, + request: AuthenticationRequest, + secret: FrontendMessage, + parameters: &HashMap, + ) -> AuthenticationStatus { + let FrontendMessage::PasswordMessage(password_message) = secret else { + return AuthenticationStatus::UnexpectedFrontendMessage; + }; + + if !matches!(request, AuthenticationRequest::CleartextPassword) { + return AuthenticationStatus::UnexpectedFrontendMessage; + } + + let user = parameters.get("user").unwrap().clone(); + let authenticate_response = service + .authenticate(Some(user.clone()), Some(password_message.password.clone())) + .await; + + let auth_fail = || { + AuthenticationStatus::Failed(format!( + "password authentication failed for user \"{}\"", + user + )) + }; + + let Ok(authenticate_response) = authenticate_response else { + return auth_fail(); + }; + + if !authenticate_response.skip_password_check { + let is_password_correct = match authenticate_response.password { + None => false, + Some(password) => password == password_message.password, + }; + if !is_password_correct { + return auth_fail(); + } + } + + AuthenticationStatus::Success(user, authenticate_response.context) + } + + fn get_pg_message_tag_parser(&self) -> Arc { + Arc::clone(&self.pg_message_tag_parser) + } +} + +crate::di_service!(PostgresAuthServiceDefaultImpl, [PostgresAuthService]); diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index 78458f2db2389..8de9ac405c5ff 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -3,7 +3,7 @@ use std::{ time::SystemTime, }; -use super::extended::PreparedStatement; +use super::{extended::PreparedStatement, pg_auth_service::AuthenticationStatus}; use crate::{ compile::{ convert_statement_to_cube_query, @@ -24,8 +24,11 @@ use crate::{ use futures::{pin_mut, FutureExt, StreamExt}; use log::{debug, error, trace}; use pg_srv::{ - buffer, protocol, - protocol::{ErrorCode, ErrorResponse, Format, InitialMessage, PortalCompletion}, + buffer, + protocol::{ + self, AuthenticationRequest, ErrorCode, ErrorResponse, Format, InitialMessage, + PortalCompletion, + }, PgType, PgTypeId, ProtocolError, }; use sqlparser::ast::{self, CloseCursor, FetchDirection, Query, SetExpr, Statement, Value}; @@ -46,10 +49,9 @@ pub struct AsyncPostgresShim { logger: Arc, } -#[derive(PartialEq, Eq)] pub enum StartupState { // Initial parameters which client sends in the first message, we use it later in auth method - Success(HashMap), + Success(HashMap, AuthenticationRequest), SslRequested, Denied, CancelRequest, @@ -313,25 +315,23 @@ impl AsyncPostgresShim { } pub async fn run(&mut self) -> Result<(), ConnectionError> { - let initial_parameters = match self.process_initial_message().await? { - StartupState::Success(parameters) => parameters, + let (initial_parameters, auth_method) = match self.process_initial_message().await? { + StartupState::Success(parameters, auth_method) => (parameters, auth_method), StartupState::SslRequested => match self.process_initial_message().await? { - StartupState::Success(parameters) => parameters, + StartupState::Success(parameters, auth_method) => (parameters, auth_method), _ => return Ok(()), }, StartupState::Denied | StartupState::CancelRequest => return Ok(()), }; - match buffer::read_message(&mut self.socket).await? { - protocol::FrontendMessage::PasswordMessage(password_message) => { - if !self - .authenticate(password_message, initial_parameters) - .await? - { - return Ok(()); - } - } - _ => return Ok(()), + let message_tag_parser = self.session.server.pg_auth.get_pg_message_tag_parser(); + let auth_secret = + buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)).await?; + if !self + .authenticate(auth_method, auth_secret, initial_parameters) + .await? + { + return Ok(()); } self.ready().await?; @@ -351,7 +351,7 @@ impl AsyncPostgresShim { true = async { semifast_shutdownable && { semifast_shutdown_interruptor.cancelled().await; true } } => { return Self::flush_and_write_admin_shutdown_fatal_message(self).await; } - message_result = buffer::read_message(&mut self.socket) => message_result? + message_result = buffer::read_message(&mut self.socket, Arc::clone(&message_tag_parser)) => message_result? }; let result = match message { @@ -716,73 +716,62 @@ impl AsyncPostgresShim { return Ok(StartupState::Denied); } - self.write(protocol::Authentication::new( - protocol::AuthenticationRequest::CleartextPassword, - )) - .await?; + let auth_method = self.session.server.pg_auth.get_auth_method(¶meters); + self.write(protocol::Authentication::new(auth_method.clone())) + .await?; - Ok(StartupState::Success(parameters)) + Ok(StartupState::Success(parameters, auth_method)) } pub async fn authenticate( &mut self, - password_message: protocol::PasswordMessage, + auth_request: AuthenticationRequest, + auth_secret: protocol::FrontendMessage, parameters: HashMap, ) -> Result { - let user = parameters.get("user").unwrap().clone(); - let authenticate_response = self + let auth_service = self.session.server.auth.clone(); + let auth_status = self .session .server - .auth - .authenticate(Some(user.clone()), Some(password_message.password.clone())) + .pg_auth + .authenticate(auth_service, auth_request, auth_secret, ¶meters) .await; + let result = match auth_status { + AuthenticationStatus::UnexpectedFrontendMessage => Err(( + "invalid authorization specification".to_string(), + protocol::ErrorCode::InvalidAuthorizationSpecification, + )), + AuthenticationStatus::Failed(err) => Err((err, protocol::ErrorCode::InvalidPassword)), + AuthenticationStatus::Success(user, auth_context) => Ok((user, auth_context)), + }; - let mut auth_context: Option = None; + match result { + Err((message, code)) => { + let error_response = protocol::ErrorResponse::fatal(code, message); + buffer::write_message( + &mut self.partial_write_buf, + &mut self.socket, + error_response, + ) + .await?; - let auth_success = match authenticate_response { - Ok(authenticate_response) => { - auth_context = Some(authenticate_response.context); - if !authenticate_response.skip_password_check { - match authenticate_response.password { - None => false, - Some(password) => password == password_message.password, - } - } else { - true - } + Ok(false) } - _ => false, - }; - - if !auth_success { - let error_response = protocol::ErrorResponse::fatal( - protocol::ErrorCode::InvalidPassword, - format!("password authentication failed for user \"{}\"", &user), - ); - buffer::write_message( - &mut self.partial_write_buf, - &mut self.socket, - error_response, - ) - .await?; + Ok((user, auth_context)) => { + let database = parameters + .get("database") + .map(|v| v.clone()) + .unwrap_or("db".to_string()); + self.session.state.set_database(Some(database)); + self.session.state.set_user(Some(user)); + self.session.state.set_auth_context(Some(auth_context)); + + self.write(protocol::Authentication::new(AuthenticationRequest::Ok)) + .await?; - return Ok(false); + Ok(true) + } } - - let database = parameters - .get("database") - .map(|v| v.clone()) - .unwrap_or("db".to_string()); - self.session.state.set_database(Some(database)); - self.session.state.set_user(Some(user)); - self.session.state.set_auth_context(auth_context); - - self.write(protocol::Authentication::new( - protocol::AuthenticationRequest::Ok, - )) - .await?; - - Ok(true) } pub async fn ready(&mut self) -> Result<(), ConnectionError> { diff --git a/rust/cubesql/cubesql/src/sql/server_manager.rs b/rust/cubesql/cubesql/src/sql/server_manager.rs index 6bdad074ae8b4..17375e86ddd4f 100644 --- a/rust/cubesql/cubesql/src/sql/server_manager.rs +++ b/rust/cubesql/cubesql/src/sql/server_manager.rs @@ -4,6 +4,7 @@ use crate::{ sql::{ compiler_cache::CompilerCache, database_variables::{mysql_default_global_variables, postgres_default_global_variables}, + pg_auth_service::PostgresAuthService, SqlAuthService, }, transport::TransportService, @@ -37,6 +38,7 @@ pub struct ServerManager { // References to shared things pub auth: Arc, pub transport: Arc, + pub pg_auth: Arc, // Non references pub configuration: ServerConfiguration, pub nonce: Option>, @@ -52,6 +54,7 @@ impl ServerManager { pub fn new( auth: Arc, transport: Arc, + pg_auth: Arc, compiler_cache: Arc, nonce: Option>, config_obj: Arc, @@ -59,6 +62,7 @@ impl ServerManager { Self { auth, transport, + pg_auth, compiler_cache, nonce, config_obj, diff --git a/rust/cubesql/pg-srv/src/buffer.rs b/rust/cubesql/pg-srv/src/buffer.rs index 27988797e0d21..f4c574030cdba 100644 --- a/rust/cubesql/pg-srv/src/buffer.rs +++ b/rust/cubesql/pg-srv/src/buffer.rs @@ -1,10 +1,13 @@ //! Helpers for reading/writing from/to the connection's socket +use async_trait::async_trait; use bytes::{BufMut, BytesMut}; use std::{ convert::TryFrom, + fmt::Debug, io::{Cursor, Error, ErrorKind}, marker::Send, + sync::Arc, }; use crate::{ @@ -16,34 +19,68 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::protocol::{self, Deserialize, FrontendMessage, Serialize}; +#[async_trait] +pub trait MessageTagParser: Sync + Send + Debug { + async fn parse( + &self, + tag: u8, + cursor: Cursor>, + ) -> Result; +} + +#[derive(Default, Debug)] +pub struct MessageTagParserDefaultImpl {} + +impl MessageTagParserDefaultImpl { + pub fn new() -> Self { + Self {} + } + + pub fn with_arc() -> Arc { + Arc::new(Self::new()) + } +} + +#[async_trait] +impl MessageTagParser for MessageTagParserDefaultImpl { + async fn parse( + &self, + tag: u8, + cursor: Cursor>, + ) -> Result { + let message = match tag { + b'Q' => FrontendMessage::Query(protocol::Query::deserialize(cursor).await?), + b'P' => FrontendMessage::Parse(protocol::Parse::deserialize(cursor).await?), + b'B' => FrontendMessage::Bind(protocol::Bind::deserialize(cursor).await?), + b'D' => FrontendMessage::Describe(protocol::Describe::deserialize(cursor).await?), + b'E' => FrontendMessage::Execute(protocol::Execute::deserialize(cursor).await?), + b'C' => FrontendMessage::Close(protocol::Close::deserialize(cursor).await?), + b'p' => FrontendMessage::PasswordMessage( + protocol::PasswordMessage::deserialize(cursor).await?, + ), + b'X' => FrontendMessage::Terminate, + b'H' => FrontendMessage::Flush, + b'S' => FrontendMessage::Sync, + identifier => { + return Err(ErrorResponse::error( + ErrorCode::DataException, + format!("Unknown message identifier: {:X?}", identifier), + ) + .into()) + } + }; + Ok(message) + } +} + pub async fn read_message( reader: &mut Reader, + parser: Arc, ) -> Result { // https://www.postgresql.org/docs/14/protocol-message-formats.html let message_tag = reader.read_u8().await?; let cursor = read_contents(reader, message_tag).await?; - - let message = match message_tag { - b'Q' => FrontendMessage::Query(protocol::Query::deserialize(cursor).await?), - b'P' => FrontendMessage::Parse(protocol::Parse::deserialize(cursor).await?), - b'B' => FrontendMessage::Bind(protocol::Bind::deserialize(cursor).await?), - b'D' => FrontendMessage::Describe(protocol::Describe::deserialize(cursor).await?), - b'E' => FrontendMessage::Execute(protocol::Execute::deserialize(cursor).await?), - b'C' => FrontendMessage::Close(protocol::Close::deserialize(cursor).await?), - b'p' => { - FrontendMessage::PasswordMessage(protocol::PasswordMessage::deserialize(cursor).await?) - } - b'X' => FrontendMessage::Terminate, - b'H' => FrontendMessage::Flush, - b'S' => FrontendMessage::Sync, - identifier => { - return Err(ErrorResponse::error( - ErrorCode::DataException, - format!("Unknown message identifier: {:X?}", identifier), - ) - .into()) - } - }; + let message = parser.parse(message_tag, cursor).await?; trace!("[pg] Decoded {:X?}", message,); diff --git a/rust/cubesql/pg-srv/src/protocol.rs b/rust/cubesql/pg-srv/src/protocol.rs index da7b0439f96b6..5b4755477ae8d 100644 --- a/rust/cubesql/pg-srv/src/protocol.rs +++ b/rust/cubesql/pg-srv/src/protocol.rs @@ -3,10 +3,12 @@ //! Message Data Types: use std::{ + any::Any, collections::HashMap, convert::TryFrom, - fmt::{self, Display, Formatter}, + fmt::{self, Debug, Display, Formatter}, io::{Cursor, Error}, + sync::Arc, }; use async_trait::async_trait; @@ -913,8 +915,12 @@ pub enum Format { Binary, } +pub trait FrontendMessageExtension: Send + Sync + Debug { + fn as_any(&self) -> &dyn Any; +} + /// All frontend messages (request which client sends to the server). -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum FrontendMessage { PasswordMessage(PasswordMessage), /// Simple Query @@ -935,6 +941,8 @@ pub enum FrontendMessage { Execute(Execute), /// Extended Query. Close Portal/Statement Close(Close), + /// Extension + Extension(Box), } /// @@ -1055,9 +1063,17 @@ impl TransactionStatus { } } +pub trait AuthenticationRequestExtension: Send + Sync { + fn as_any(&self) -> &dyn Any; + + fn to_code(&self) -> u32; +} + +#[derive(Clone)] pub enum AuthenticationRequest { Ok, CleartextPassword, + Extension(Arc), } impl AuthenticationRequest { @@ -1069,6 +1085,7 @@ impl AuthenticationRequest { match self { Self::Ok => 0, Self::CleartextPassword => 3, + Self::Extension(extension) => extension.to_code(), } } } @@ -1093,7 +1110,7 @@ pub trait Deserialize { #[cfg(test)] mod tests { use super::*; - use crate::{read_message, ProtocolError}; + use crate::{read_message, MessageTagParserDefaultImpl, ProtocolError}; use std::io::Cursor; @@ -1171,7 +1188,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::Parse(parse) => { assert_eq!( @@ -1201,7 +1218,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::Bind(bind) => { assert_eq!( @@ -1236,7 +1253,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::Bind(body) => { assert_eq!( @@ -1272,7 +1289,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::Describe(desc) => { assert_eq!( @@ -1299,7 +1316,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::PasswordMessage(body) => { assert_eq!( @@ -1325,7 +1342,7 @@ mod tests { ); let mut cursor = Cursor::new(buffer); - let message = read_message(&mut cursor).await?; + let message = read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; match message { FrontendMessage::Execute(body) => { assert_eq!( @@ -1355,8 +1372,8 @@ mod tests { // This test demonstrates that protocol can decode two // simple messages without body in sequence - read_message(&mut cursor).await?; - read_message(&mut cursor).await?; + read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; + read_message(&mut cursor, MessageTagParserDefaultImpl::with_arc()).await?; Ok(()) }