Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(query): refactor license manager #16492

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions src/binaries/query/entry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,30 @@ pub async fn start_services(conf: &InnerConfig) -> Result<(), MainError> {

info!("Databend Query start with config: {:?}", conf);

// Cluster register.
{
ClusterDiscovery::instance()
.register_to_metastore(conf)
.await
.with_context(make_error)?;
info!(
"Databend query has been registered:{:?} to metasrv:{:?}.",
conf.query.cluster_id, conf.meta.endpoints
);
}

// RPC API service.
{
let address = conf.query.flight_api_address.clone();
let mut srv = FlightService::create(conf.clone()).with_context(make_error)?;
let listening = srv
.start(address.parse().with_context(make_error)?)
.await
.with_context(make_error)?;
shutdown_handle.add_service("RPCService", srv);
info!("Listening for RPC API (interserver): {}", listening);
}

// MySQL handler.
{
let hostname = conf.query.mysql_handler_host.clone();
Expand Down Expand Up @@ -229,30 +253,6 @@ pub async fn start_services(conf: &InnerConfig) -> Result<(), MainError> {
info!("Listening for FlightSQL API: {}", listening);
}

// RPC API service.
{
let address = conf.query.flight_api_address.clone();
let mut srv = FlightService::create(conf.clone()).with_context(make_error)?;
let listening = srv
.start(address.parse().with_context(make_error)?)
.await
.with_context(make_error)?;
shutdown_handle.add_service("RPCService", srv);
info!("Listening for RPC API (interserver): {}", listening);
}

// Cluster register.
{
ClusterDiscovery::instance()
.register_to_metastore(conf)
.await
.with_context(make_error)?;
info!(
"Databend query has been registered:{:?} to metasrv:{:?}.",
conf.query.cluster_id, conf.meta.endpoints
);
}

// Print information to users.
println!("Databend Query");

Expand Down
1 change: 1 addition & 0 deletions src/common/exception/src/exception_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ build_exceptions! {
/// For example: license key is expired
LicenseKeyInvalid(1402),
EnterpriseFeatureNotEnable(1403),
LicenseKeyExpired(1404),

BackgroundJobAlreadyExists(1501),
UnknownBackgroundJob(1502),
Expand Down
21 changes: 13 additions & 8 deletions src/common/license/src/license.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::fmt;

use databend_common_base::display::display_option::DisplayOptionExt;
use databend_common_base::display::display_slice::DisplaySliceExt;
use databend_common_exception::ErrorCode;
use serde::Deserialize;
use serde::Serialize;

Expand Down Expand Up @@ -124,31 +125,35 @@ impl fmt::Display for Feature {
}

impl Feature {
pub fn verify(&self, feature: &Feature) -> bool {
pub fn verify_default(&self, message: impl Into<String>) -> Result<(), ErrorCode> {
Err(ErrorCode::LicenseKeyInvalid(message.into()))
}

pub fn verify(&self, feature: &Feature) -> Result<bool, ErrorCode> {
match (self, feature) {
(Feature::ComputeQuota(c), Feature::ComputeQuota(v)) => {
if let Some(thread_num) = c.threads_num {
if thread_num <= v.threads_num.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

if let Some(max_memory_usage) = c.memory_usage {
if max_memory_usage <= v.memory_usage.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

true
Ok(true)
}
(Feature::StorageQuota(c), Feature::StorageQuota(v)) => {
if let Some(max_storage_usage) = c.storage_usage {
if max_storage_usage <= v.storage_usage.unwrap_or(usize::MAX) {
return false;
return Ok(false);
}
}

true
Ok(true)
}
(Feature::Test, Feature::Test)
| (Feature::AggregateIndex, Feature::AggregateIndex)
Expand All @@ -161,8 +166,8 @@ impl Feature {
| (Feature::InvertedIndex, Feature::InvertedIndex)
| (Feature::VirtualColumn, Feature::VirtualColumn)
| (Feature::AttacheTable, Feature::AttacheTable)
| (Feature::StorageEncryption, Feature::StorageEncryption) => true,
(_, _) => false,
| (Feature::StorageEncryption, Feature::StorageEncryption) => Ok(true),
(_, _) => Ok(false),
}
}
}
Expand Down
7 changes: 3 additions & 4 deletions src/common/license/src/license_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,9 @@ impl LicenseManager for OssLicenseManager {
GlobalInstance::get()
}

fn check_enterprise_enabled(&self, _license_key: String, _feature: Feature) -> Result<()> {
Err(ErrorCode::LicenseKeyInvalid(
"Need Commercial License".to_string(),
))
fn check_enterprise_enabled(&self, _license_key: String, feature: Feature) -> Result<()> {
// oss ignore license key.
feature.verify_default("Need Commercial License".to_string())
}

fn parse_license(&self, _raw: &str) -> Result<JWTClaims<LicenseInfo>> {
Expand Down
94 changes: 57 additions & 37 deletions src/query/ee/src/license/license_mgr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use jwt_simple::algorithms::ES256PublicKey;
use jwt_simple::claims::JWTClaims;
use jwt_simple::prelude::Clock;
use jwt_simple::prelude::ECDSAP256PublicKeyLike;
use jwt_simple::JWTError;

const LICENSE_PUBLIC_KEY: &str = r#"-----BEGIN PUBLIC KEY-----
MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEGsKCbhXU7j56VKZ7piDlLXGhud0a
Expand Down Expand Up @@ -60,34 +61,41 @@ impl LicenseManager for RealLicenseManager {

fn check_enterprise_enabled(&self, license_key: String, feature: Feature) -> Result<()> {
if license_key.is_empty() {
return Err(ErrorCode::LicenseKeyInvalid(format!(
"use of {feature} requires an enterprise license. license key is not found for {}",
return feature.verify_default(format!(
"The use of this feature requires a Databend Enterprise Edition license. No license key found for tenant: {}. To unlock enterprise features, please contact Databend to obtain a license. Learn more at https://docs.databend.com/guides/overview/editions/dee/",
self.tenant
)));
));
}

if let Some(v) = self.cache.get(&license_key) {
return Self::verify_feature(v.value(), feature);
return self.verify_feature(v.value(), feature);
}

let license = self.parse_license(&license_key).map_err_to_code(
ErrorCode::LicenseKeyInvalid,
|| format!("use of {feature} requires an enterprise license. current license is invalid for {}", self.tenant),
)?;
Self::verify_feature(&license, feature)?;
self.cache.insert(license_key, license);
Ok(())
match self.parse_license(&license_key) {
Ok(license) => {
self.verify_feature(&license, feature)?;
self.cache.insert(license_key, license);
Ok(())
}
Err(e) => match e.code() == ErrorCode::LICENSE_KEY_EXPIRED {
true => self.verify_if_expired(feature),
false => Err(e),
},
}
}

fn parse_license(&self, raw: &str) -> Result<JWTClaims<LicenseInfo>> {
let public_key = ES256PublicKey::from_pem(self.public_key.as_str())
.map_err_to_code(ErrorCode::LicenseKeyParseError, || "public key load failed")?;
public_key
.verify_token::<LicenseInfo>(raw, None)
.map_err_to_code(
ErrorCode::LicenseKeyParseError,
|| "jwt claim decode failed",
)
match public_key.verify_token::<LicenseInfo>(raw, None) {
Ok(v) => Ok(v),
Err(cause) => match cause.downcast_ref::<JWTError>() {
Some(JWTError::TokenHasExpired) => {
Err(ErrorCode::LicenseKeyExpired("license key is expired."))
}
_ => Err(ErrorCode::LicenseKeyParseError("jwt claim decode failed")),
},
}
}

fn get_storage_quota(&self, license_key: String) -> Result<StorageQuota> {
Expand All @@ -96,15 +104,26 @@ impl LicenseManager for RealLicenseManager {
}

if let Some(v) = self.cache.get(&license_key) {
Self::verify_license(v.value())?;
if Self::verify_license_expired(v.value())? {
return Err(ErrorCode::LicenseKeyExpired(format!(
"license key expired in {:?}",
v.value().expires_at,
)));
}
return Ok(v.custom.get_storage_quota());
}

let license = self.parse_license(&license_key).map_err_to_code(
ErrorCode::LicenseKeyInvalid,
|| format!("use of storage requires an enterprise license. current license is invalid for {}", self.tenant),
)?;
Self::verify_license(&license)?;

if Self::verify_license_expired(&license)? {
return Err(ErrorCode::LicenseKeyExpired(format!(
"license key expired in {:?}",
license.expires_at,
)));
}

let quota = license.custom.get_storage_quota();
self.cache.insert(license_key, license);
Expand All @@ -123,36 +142,28 @@ impl RealLicenseManager {
}
}

fn verify_license(l: &JWTClaims<LicenseInfo>) -> Result<()> {
fn verify_license_expired(l: &JWTClaims<LicenseInfo>) -> Result<bool> {
let now = Clock::now_since_epoch();
match l.expires_at {
Some(expire_at) => {
if now > expire_at {
return Err(ErrorCode::LicenseKeyInvalid(format!(
"license key expired in {:?}",
expire_at
)));
}
}
None => {
return Err(ErrorCode::LicenseKeyInvalid(
"cannot find valid expire time",
));
}
Some(expire_at) => Ok(now > expire_at),
None => Err(ErrorCode::LicenseKeyInvalid(
"cannot find valid expire time",
)),
}
Ok(())
}

fn verify_feature(l: &JWTClaims<LicenseInfo>, feature: Feature) -> Result<()> {
Self::verify_license(l)?;
fn verify_feature(&self, l: &JWTClaims<LicenseInfo>, feature: Feature) -> Result<()> {
if Self::verify_license_expired(l)? {
return self.verify_if_expired(feature);
}

if l.custom.features.is_none() {
return Ok(());
}

let verify_features = l.custom.features.as_ref().unwrap();
for verify_feature in verify_features {
if verify_feature.verify(&feature) {
if verify_feature.verify(&feature)? {
return Ok(());
}
}
Expand All @@ -163,4 +174,13 @@ impl RealLicenseManager {
l.custom.display_features()
)))
}

fn verify_if_expired(&self, feature: Feature) -> Result<()> {
feature.verify_default("").map_err(|_|
ErrorCode::LicenseKeyExpired(format!(
"The use of this feature requires a Databend Enterprise Edition license. License key has expired for tenant: {}. To unlock enterprise features, please contact Databend to obtain a license. Learn more at https://docs.databend.com/guides/overview/editions/dee/",
self.tenant
))
)
}
}
Loading