iceberg_catalog_rest/
client.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::fmt::{Debug, Formatter};
20
21use http::StatusCode;
22use iceberg::{Error, ErrorKind, Result};
23use reqwest::header::HeaderMap;
24use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response};
25use serde::de::DeserializeOwned;
26use tokio::sync::Mutex;
27
28use crate::RestCatalogConfig;
29use crate::types::{ErrorResponse, TokenResponse};
30
31pub(crate) struct HttpClient {
32    client: Client,
33
34    /// The token to be used for authentication.
35    ///
36    /// It's possible to fetch the token from the server while needed.
37    token: Mutex<Option<String>>,
38    /// The token endpoint to be used for authentication.
39    token_endpoint: String,
40    /// The credential to be used for authentication.
41    credential: Option<(Option<String>, String)>,
42    /// Extra headers to be added to each request.
43    extra_headers: HeaderMap,
44    /// Extra oauth parameters to be added to each authentication request.
45    extra_oauth_params: HashMap<String, String>,
46}
47
48impl Debug for HttpClient {
49    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("HttpClient")
51            .field("client", &self.client)
52            .field("extra_headers", &self.extra_headers)
53            .finish_non_exhaustive()
54    }
55}
56
57impl HttpClient {
58    /// Create a new http client.
59    pub fn new(cfg: &RestCatalogConfig) -> Result<Self> {
60        let extra_headers = cfg.extra_headers()?;
61        Ok(HttpClient {
62            client: cfg.client().unwrap_or_default(),
63            token: Mutex::new(cfg.token()),
64            token_endpoint: cfg.get_token_endpoint(),
65            credential: cfg.credential(),
66            extra_headers,
67            extra_oauth_params: cfg.extra_oauth_params(),
68        })
69    }
70
71    /// Update the http client with new configuration.
72    ///
73    /// If cfg carries new value, we will use cfg instead.
74    /// Otherwise, we will keep the old value.
75    pub fn update_with(self, cfg: &RestCatalogConfig) -> Result<Self> {
76        let extra_headers = (!cfg.extra_headers()?.is_empty())
77            .then(|| cfg.extra_headers())
78            .transpose()?
79            .unwrap_or(self.extra_headers);
80        Ok(HttpClient {
81            client: cfg.client().unwrap_or(self.client),
82            token: Mutex::new(cfg.token().or_else(|| self.token.into_inner())),
83            token_endpoint: if !cfg.get_token_endpoint().is_empty() {
84                cfg.get_token_endpoint()
85            } else {
86                self.token_endpoint
87            },
88            credential: cfg.credential().or(self.credential),
89            extra_headers,
90            extra_oauth_params: if !cfg.extra_oauth_params().is_empty() {
91                cfg.extra_oauth_params()
92            } else {
93                self.extra_oauth_params
94            },
95        })
96    }
97
98    /// This API is testing only to assert the token.
99    #[cfg(test)]
100    pub(crate) async fn token(&self) -> Option<String> {
101        let mut req = self
102            .request(Method::GET, &self.token_endpoint)
103            .build()
104            .unwrap();
105        self.authenticate(&mut req).await.ok();
106        self.token.lock().await.clone()
107    }
108
109    async fn exchange_credential_for_token(&self) -> Result<String> {
110        // Credential must exist here.
111        let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| {
112            Error::new(
113                ErrorKind::DataInvalid,
114                "Credential must be provided for authentication",
115            )
116        })?;
117
118        let mut params = HashMap::with_capacity(4);
119        params.insert("grant_type", "client_credentials");
120        if let Some(client_id) = client_id {
121            params.insert("client_id", client_id);
122        }
123        params.insert("client_secret", client_secret);
124        params.extend(
125            self.extra_oauth_params
126                .iter()
127                .map(|(k, v)| (k.as_str(), v.as_str())),
128        );
129
130        let mut auth_req = self
131            .request(Method::POST, &self.token_endpoint)
132            .form(&params)
133            .build()?;
134        // extra headers add content-type application/json header it's necessary to override it with proper type
135        // note that form call doesn't add content-type header if already present
136        auth_req.headers_mut().insert(
137            http::header::CONTENT_TYPE,
138            http::HeaderValue::from_static("application/x-www-form-urlencoded"),
139        );
140        let auth_url = auth_req.url().clone();
141        let auth_resp = self.client.execute(auth_req).await?;
142
143        let auth_res: TokenResponse = if auth_resp.status() == StatusCode::OK {
144            let text = auth_resp
145                .bytes()
146                .await
147                .map_err(|err| err.with_url(auth_url.clone()))?;
148            Ok(serde_json::from_slice(&text).map_err(|e| {
149                Error::new(
150                    ErrorKind::Unexpected,
151                    "Failed to parse response from rest catalog server!",
152                )
153                .with_context("operation", "auth")
154                .with_context("url", auth_url.to_string())
155                .with_context("json", String::from_utf8_lossy(&text))
156                .with_source(e)
157            })?)
158        } else {
159            let code = auth_resp.status();
160            let text = auth_resp
161                .bytes()
162                .await
163                .map_err(|err| err.with_url(auth_url.clone()))?;
164            let e: ErrorResponse = serde_json::from_slice(&text).map_err(|e| {
165                Error::new(ErrorKind::Unexpected, "Received unexpected response")
166                    .with_context("code", code.to_string())
167                    .with_context("operation", "auth")
168                    .with_context("url", auth_url.to_string())
169                    .with_context("json", String::from_utf8_lossy(&text))
170                    .with_source(e)
171            })?;
172            Err(Error::from(e))
173        }?;
174        Ok(auth_res.access_token)
175    }
176
177    /// Invalidate the current token without generating a new one. On the next request, the client
178    /// will attempt to generate a new token.
179    pub(crate) async fn invalidate_token(&self) -> Result<()> {
180        *self.token.lock().await = None;
181        Ok(())
182    }
183
184    /// Invalidate the current token and set a new one. Generates a new token before invalidating
185    /// the current token, meaning the old token will be used until this function acquires the lock
186    /// and overwrites the token.
187    ///
188    /// If credential is invalid, or the request fails, this method will return an error and leave
189    /// the current token unchanged.
190    pub(crate) async fn regenerate_token(&self) -> Result<()> {
191        let new_token = self.exchange_credential_for_token().await?;
192        *self.token.lock().await = Some(new_token.clone());
193        Ok(())
194    }
195
196    /// Authenticates the request by adding a bearer token to the authorization header.
197    ///
198    /// This method supports three authentication modes:
199    ///
200    /// 1. **No authentication** - Skip authentication when both `credential` and `token` are missing.
201    /// 2. **Token authentication** - Use the provided `token` directly for authentication.
202    /// 3. **OAuth authentication** - Exchange `credential` for a token, cache it, then use it for authentication.
203    ///
204    /// When both `credential` and `token` are present, `token` takes precedence.
205    ///
206    /// # TODO: Support automatic token refreshing.
207    async fn authenticate(&self, req: &mut Request) -> Result<()> {
208        // Clone the token from lock without holding the lock for entire function.
209        let token = self.token.lock().await.clone();
210
211        if self.credential.is_none() && token.is_none() {
212            return Ok(());
213        }
214
215        // Either use the provided token or exchange credential for token, cache and use that
216        let token = match token {
217            Some(token) => token,
218            None => {
219                let token = self.exchange_credential_for_token().await?;
220                // Update token so that we use it for next request instead of
221                // exchanging credential for token from the server again
222                *self.token.lock().await = Some(token.clone());
223                token
224            }
225        };
226
227        // Insert token in request.
228        req.headers_mut().insert(
229            http::header::AUTHORIZATION,
230            format!("Bearer {token}").parse().map_err(|e| {
231                Error::new(
232                    ErrorKind::DataInvalid,
233                    "Invalid token received from catalog server!",
234                )
235                .with_source(e)
236            })?,
237        );
238
239        Ok(())
240    }
241
242    #[inline]
243    pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
244        self.client
245            .request(method, url)
246            .headers(self.extra_headers.clone())
247    }
248
249    /// Executes the given `Request` and returns a `Response`.
250    pub async fn execute(&self, mut request: Request) -> Result<Response> {
251        request.headers_mut().extend(self.extra_headers.clone());
252        Ok(self.client.execute(request).await?)
253    }
254
255    // Queries the Iceberg REST catalog after authentication with the given `Request` and
256    // returns a `Response`.
257    pub async fn query_catalog(&self, mut request: Request) -> Result<Response> {
258        self.authenticate(&mut request).await?;
259        self.execute(request).await
260    }
261}
262
263/// Deserializes a catalog response into the given [`DeserializedOwned`] type.
264///
265/// Returns an error if unable to parse the response bytes.
266pub(crate) async fn deserialize_catalog_response<R: DeserializeOwned>(
267    response: Response,
268) -> Result<R> {
269    let bytes = response.bytes().await?;
270
271    serde_json::from_slice::<R>(&bytes).map_err(|e| {
272        Error::new(
273            ErrorKind::Unexpected,
274            "Failed to parse response from rest catalog server",
275        )
276        .with_context("json", String::from_utf8_lossy(&bytes))
277        .with_source(e)
278    })
279}
280
281/// Deserializes a unexpected catalog response into an error.
282pub(crate) async fn deserialize_unexpected_catalog_error(response: Response) -> Error {
283    let err = Error::new(
284        ErrorKind::Unexpected,
285        "Received response with unexpected status code",
286    )
287    .with_context("status", response.status().to_string())
288    .with_context("headers", format!("{:?}", response.headers()));
289
290    let bytes = match response.bytes().await {
291        Ok(bytes) => bytes,
292        Err(err) => return err.into(),
293    };
294
295    if bytes.is_empty() {
296        return err;
297    }
298    err.with_context("json", String::from_utf8_lossy(&bytes))
299}