diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 2a02748a10..184196984f 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -192,7 +192,8 @@ async fn authenticate( }; let conn_info = compute::ConnectInfo { - host: db_info.host.into(), + host: db_info.host.as_ref().into(), + server_name: db_info.host.into_string(), port: db_info.port, ssl_mode, host_addr: None, diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 2224f492b8..a86f726b24 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -33,6 +33,7 @@ impl LocalBackend { conn_info: ConnectInfo { host_addr: Some(postgres_addr.ip()), host: postgres_addr.ip().to_string().into(), + server_name: postgres_addr.ip().to_string(), port: postgres_addr.port(), ssl_mode: SslMode::Disable, }, diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index ca784423ee..f0452d1d79 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -151,6 +151,7 @@ pub(crate) struct AuthInfo { pub struct ConnectInfo { pub host_addr: Option, pub host: Host, + pub server_name: String, pub port: u16, pub ssl_mode: SslMode, } @@ -303,6 +304,7 @@ impl ConnectInfo { // require for our business. let port = self.port; let host = &*self.host; + let server_name = &*self.server_name; let addrs = match self.host_addr { Some(addr) => vec![SocketAddr::new(addr, port)], @@ -312,7 +314,7 @@ impl ConnectInfo { match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok(( sockaddr, - tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?, + tls::connect_tls(stream, self.ssl_mode, config, server_name, tls).await?, )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index b76b13e2c2..9db6f8c8a2 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -31,7 +31,7 @@ use crate::control_plane::{ use crate::metrics::Metrics; use crate::proxy::retry::CouldRetry; use crate::rate_limiter::WakeComputeRateLimiter; -use crate::types::{EndpointCacheKey, EndpointId, RoleName}; +use crate::types::{EndpointCacheKey, EndpointId, Host, RoleName}; use crate::{compute, http, scram}; pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -313,7 +313,7 @@ impl NeonControlPlaneClient { Some(_) => SslMode::Require, None => SslMode::Disable, }; - let host = match body.server_name { + let host: Host = match body.server_name { Some(host) => host.into(), None => host.into(), }; @@ -321,6 +321,7 @@ impl NeonControlPlaneClient { let node = NodeInfo { conn_info: compute::ConnectInfo { host_addr, + server_name: host.to_string(), host, port, ssl_mode, diff --git a/proxy/src/control_plane/client/lakebase_v1.rs b/proxy/src/control_plane/client/lakebase_v1.rs new file mode 100644 index 0000000000..2959b625cd --- /dev/null +++ b/proxy/src/control_plane/client/lakebase_v1.rs @@ -0,0 +1,104 @@ +//! Production console backend. + +use std::sync::Arc; + +use ::http::HeaderName; +use postgres_client::config::SslMode; + +use crate::auth::backend::ComputeUserInfo; +use crate::auth::backend::jwt::AuthRule; +use crate::compute::ConnectInfo; +use crate::context::RequestContext; +use crate::control_plane::errors::{ + ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, +}; +use crate::control_plane::messages::{ColdStartInfo, EndpointRateLimitConfig, MetricsAuxInfo}; +use crate::control_plane::{ + AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, NodeInfo, RoleAccessControl, +}; +use crate::intern::{BranchIdInt, ProjectIdInt}; +use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; + +pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); + +#[derive(Clone)] +pub struct LakebaseClient { + pub namespace: String, + pub port: u16, +} + +impl LakebaseClient { + /// Construct an API object containing the auth parameters. + pub fn new(namespace: String, port: u16) -> Self { + Self { namespace, port } + } +} + +impl super::ControlPlaneApi for LakebaseClient { + #[tracing::instrument(skip_all)] + async fn get_role_access_control( + &self, + _ctx: &RequestContext, + _endpoint: &EndpointId, + _role: &RoleName, + ) -> Result { + Ok(RoleAccessControl { secret: None }) + } + + #[tracing::instrument(skip_all)] + async fn get_endpoint_access_control( + &self, + _ctx: &RequestContext, + _endpoint: &EndpointId, + _role: &RoleName, + ) -> Result { + Ok(EndpointAccessControl { + allowed_ips: Arc::new(vec![]), + allowed_vpce: Arc::new(vec![]), + flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), + }) + } + + #[tracing::instrument(skip_all)] + async fn get_endpoint_jwks( + &self, + _ctx: &RequestContext, + _endpoint: &EndpointId, + ) -> Result, GetEndpointJwksError> { + Err(GetEndpointJwksError::ControlPlane( + ControlPlaneError::Transport(std::io::Error::other("unsupported")), + )) + } + + #[tracing::instrument(skip_all)] + async fn wake_compute( + &self, + _ctx: &RequestContext, + user_info: &ComputeUserInfo, + ) -> Result { + let instance_id = user_info.endpoint.as_str(); + let namespace = self.namespace.as_str(); + + let host = format!("{instance_id}.{namespace}.svc.cluster.local").into(); + let server_name = format!("{instance_id}.online-tables.dev.databricks.com"); + let port = self.port; + + Ok(CachedNodeInfo::new_uncached(NodeInfo { + conn_info: ConnectInfo { + host_addr: None, + host, + server_name, + port, + ssl_mode: SslMode::Require, + }, + aux: MetricsAuxInfo { + endpoint_id: user_info.endpoint.normalize_intern(), + project_id: ProjectIdInt::from(&ProjectId::from("unknown")), + branch_id: BranchIdInt::from(&BranchId::from("unknown")), + compute_id: user_info.endpoint.as_str().into(), + cold_start_info: ColdStartInfo::WarmCached, + }, + })) + } +} diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 9e48d91340..a40cddc10f 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -175,12 +175,14 @@ impl MockControlPlane { let conn_info = match self.endpoint.host_str() { None => ConnectInfo { host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), + server_name: "localhost".into(), host: "localhost".into(), port, ssl_mode: SslMode::Disable, }, Some(host) => ConnectInfo { host_addr: IpAddr::from_str(host).ok(), + server_name: host.into(), host: host.into(), port, ssl_mode: SslMode::Disable, diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index ec26746873..be80171076 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -1,4 +1,5 @@ pub mod cplane_proxy_v1; +pub mod lakebase_v1; #[cfg(any(test, feature = "testing"))] pub mod mock; diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 7e0710749e..d923f4b260 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -531,6 +531,7 @@ fn helper_create_uncached_node_info() -> NodeInfo { NodeInfo { conn_info: compute::ConnectInfo { host: "test".into(), + server_name: "test".into(), port: 5432, ssl_mode: SslMode::Disable, host_addr: None,