use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Mutex;
use http::StatusCode;
use iceberg::{Error, ErrorKind, Result};
use reqwest::header::HeaderMap;
use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response};
use serde::de::DeserializeOwned;
use crate::types::{ErrorResponse, TokenResponse};
use crate::RestCatalogConfig;
/// HTTP client used to talk to the REST catalog server.
///
/// Wraps a `reqwest` [`Client`] and handles bearer-token authentication: a
/// token is either supplied by the configuration or obtained by exchanging
/// `credential` at `token_endpoint`, and is then cached in `token`.
pub(crate) struct HttpClient {
    /// Underlying HTTP client used to send every request.
    client: Client,
    /// Cached bearer token. Kept behind a `Mutex` so that authentication,
    /// which only has `&self`, can store a freshly exchanged token.
    token: Mutex<Option<String>>,
    /// URL the credential is POSTed to when no token is cached.
    token_endpoint: String,
    /// Optional `(client_id, client_secret)` pair used for the token exchange;
    /// the client id itself may be absent.
    credential: Option<(Option<String>, String)>,
    /// Additional headers read from the catalog configuration.
    extra_headers: HeaderMap,
    /// Extra form parameters added to the token-exchange request; because they
    /// are inserted last they can override the standard OAuth parameters.
    extra_oauth_params: HashMap<String, String>,
}
impl Debug for HttpClient {
    /// Formats the client for diagnostics.
    ///
    /// Only the underlying `reqwest` client and the extra headers are printed.
    /// The cached token, credential, token endpoint and OAuth parameters are
    /// omitted, and the output is terminated with `..` to signal that.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HttpClient");
        out.field("client", &self.client);
        out.field("extra_headers", &self.extra_headers);
        out.finish_non_exhaustive()
    }
}
impl HttpClient {
pub fn new(cfg: &RestCatalogConfig) -> Result<Self> {
Ok(HttpClient {
client: Client::new(),
token: Mutex::new(cfg.token()),
token_endpoint: cfg.get_token_endpoint(),
credential: cfg.credential(),
extra_headers: cfg.extra_headers()?,
extra_oauth_params: cfg.extra_oauth_params(),
})
}
pub fn update_with(self, cfg: &RestCatalogConfig) -> Result<Self> {
Ok(HttpClient {
client: self.client,
token: Mutex::new(
cfg.token()
.or_else(|| self.token.into_inner().ok().flatten()),
),
token_endpoint: (!cfg.get_token_endpoint().is_empty())
.then(|| cfg.get_token_endpoint())
.unwrap_or(self.token_endpoint),
credential: cfg.credential().or(self.credential),
extra_headers: (!cfg.extra_headers()?.is_empty())
.then(|| cfg.extra_headers())
.transpose()?
.unwrap_or(self.extra_headers),
extra_oauth_params: (!cfg.extra_oauth_params().is_empty())
.then(|| cfg.extra_oauth_params())
.unwrap_or(self.extra_oauth_params),
})
}
#[cfg(test)]
pub(crate) async fn token(&self) -> Option<String> {
let mut req = self
.request(Method::GET, &self.token_endpoint)
.build()
.unwrap();
self.authenticate(&mut req).await.ok();
self.token.lock().unwrap().clone()
}
async fn authenticate(&self, req: &mut Request) -> Result<()> {
let token = { self.token.lock().expect("lock poison").clone() };
if self.credential.is_none() && token.is_none() {
return Ok(());
}
if let Some(token) = &token {
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}").parse().map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
"Invalid token received from catalog server!",
)
.with_source(e)
})?,
);
return Ok(());
}
let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
"Credential must be provided for authentication",
)
})?;
let mut params = HashMap::with_capacity(4);
params.insert("grant_type", "client_credentials");
if let Some(client_id) = client_id {
params.insert("client_id", client_id);
}
params.insert("client_secret", client_secret);
params.extend(
self.extra_oauth_params
.iter()
.map(|(k, v)| (k.as_str(), v.as_str())),
);
let auth_req = self
.client
.request(Method::POST, &self.token_endpoint)
.form(¶ms)
.build()?;
let auth_url = auth_req.url().clone();
let auth_resp = self.client.execute(auth_req).await?;
let auth_res: TokenResponse = if auth_resp.status() == StatusCode::OK {
let text = auth_resp
.bytes()
.await
.map_err(|err| err.with_url(auth_url.clone()))?;
Ok(serde_json::from_slice(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("operation", "auth")
.with_context("url", auth_url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?)
} else {
let code = auth_resp.status();
let text = auth_resp
.bytes()
.await
.map_err(|err| err.with_url(auth_url.clone()))?;
let e: ErrorResponse = serde_json::from_slice(&text).map_err(|e| {
Error::new(ErrorKind::Unexpected, "Received unexpected response")
.with_context("code", code.to_string())
.with_context("operation", "auth")
.with_context("url", auth_url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(Error::from(e))
}?;
let token = auth_res.access_token;
*self.token.lock().expect("lock poison") = Some(token.clone());
req.headers_mut().insert(
http::header::AUTHORIZATION,
format!("Bearer {token}").parse().map_err(|e| {
Error::new(
ErrorKind::DataInvalid,
"Invalid token received from catalog server!",
)
.with_source(e)
})?,
);
Ok(())
}
#[inline]
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.client.request(method, url)
}
pub async fn query<R: DeserializeOwned, E: DeserializeOwned + Into<Error>>(
&self,
mut request: Request,
) -> Result<R> {
self.authenticate(&mut request).await?;
let method = request.method().clone();
let url = request.url().clone();
let response = self.client.execute(request).await?;
if response.status() == StatusCode::OK {
let text = response
.bytes()
.await
.map_err(|err| err.with_url(url.clone()))?;
Ok(serde_json::from_slice::<R>(&text).map_err(|e| {
Error::new(
ErrorKind::Unexpected,
"Failed to parse response from rest catalog server!",
)
.with_context("method", method.to_string())
.with_context("url", url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?)
} else {
let code = response.status();
let text = response
.bytes()
.await
.map_err(|err| err.with_url(url.clone()))?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(ErrorKind::Unexpected, "Received unexpected response")
.with_context("code", code.to_string())
.with_context("method", method.to_string())
.with_context("url", url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(e.into())
}
}
pub async fn execute<E: DeserializeOwned + Into<Error>>(
&self,
mut request: Request,
) -> Result<()> {
self.authenticate(&mut request).await?;
let method = request.method().clone();
let url = request.url().clone();
let response = self.client.execute(request).await?;
match response.status() {
StatusCode::OK | StatusCode::NO_CONTENT => Ok(()),
code => {
let text = response
.bytes()
.await
.map_err(|err| err.with_url(url.clone()))?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(ErrorKind::Unexpected, "Received unexpected response")
.with_context("code", code.to_string())
.with_context("method", method.to_string())
.with_context("url", url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(e.into())
}
}
}
pub async fn do_execute<R, E: DeserializeOwned + Into<Error>>(
&self,
mut request: Request,
handler: impl FnOnce(&Response) -> Option<R>,
) -> Result<R> {
self.authenticate(&mut request).await?;
let method = request.method().clone();
let url = request.url().clone();
let response = self.client.execute(request).await?;
if let Some(ret) = handler(&response) {
Ok(ret)
} else {
let code = response.status();
let text = response
.bytes()
.await
.map_err(|err| err.with_url(url.clone()))?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(ErrorKind::Unexpected, "Received unexpected response")
.with_context("code", code.to_string())
.with_context("method", method.to_string())
.with_context("url", url.to_string())
.with_context("json", String::from_utf8_lossy(&text))
.with_source(e)
})?;
Err(e.into())
}
}
}