欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 社会 > Rust使用config加载Toml配置文件

Rust使用config加载Toml配置文件

2024/10/23 16:00:47 来源:https://blog.csdn.net/weixin_62799021/article/details/143006327  浏览:    关键词:Rust使用config加载Toml配置文件

前面提到用dotenvy读取配置文件到环境变量:https://juejin.cn/post/7411407565357449225

这里从配置文件中读取配置

添加依赖(这里使用 TOML 配置)

# 异步运行时
tokio = { version = "1", features = ["full"] }
# 序列化和反序列化数据
serde = { version = "1.0.127", features = ["derive"] }
# 动态修改配置
config = "0.14.0"
# 枚举处理
strum = { version = "0.26", features = ["derive"] }
# 假数据,用于测试
fake = { version = "2.9.2", features = ["derive", "uuid", "chrono"] }
# 异步锁
once_cell = "1.20.2"
# 错误处理
anyhow = "1.0.86"
# 序列化JSON
serde_json = "1.0.128"
# 自定义错误
thiserror = "1.0.64"
# 读取env
dotenvy = "0.15.7"
# 分布式跟踪的 SDK,用于采集监控数据,这里用其日志功能
tracing = "0.1.40"
# 日志
log = "0.4.22"
# 日志派生
log-derive = "0.4.1"
# 日志过滤器
tracing-subscriber = { version = "0.3", default-features = true, features = ["std","env-filter","registry","local-time","fmt",
] }
# 日志记录器
tracing-appender = "0.2.3"
# redis 客户端
redis = { version = "0.27.4", features = ["aio", "tokio-comp"] }
# 使用tokio实现的连接池,支持postgres、redis、redis cluster、rsmq等
bb8 = "0.8.5"
bb8-redis = "0.17.0"
# 异步 WebSocket
tokio-tungstenite = "0.24.0"

新建四个配置文件

  • default.toml:默认配置
  • development.toml:开发配置
  • production.toml:生产配置
  • test.toml:测试配置

settings\default.toml

debug = true
[network]
# 这里的配置会被development中的配置覆盖
host = "0.0.0.0"
port = 8080
[database]
url = "postgres://postgres:root123456@localhost:5432/postgres"
[redis]
url = "redis://127.0.0.1:6379"

settings\development.toml

debug = true
[network]
host = "127.0.0.1"
port = 8080
[database]
url = "postgres://postgres:root123456@localhost:5432/postgres"
[redis]
url = "redis://127.0.0.1:6379"

config\settings.rs配置信息

use std::env;
use config::{ Config, ConfigError, Environment, File };
use serde::Deserialize;#[derive(Debug, Deserialize)]
#[allow(unused)]
pub struct Network {pub host: String,pub port: u16,
}
#[derive(Debug, Deserialize)]
#[allow(unused)]
pub struct Database {pub url: String,
}
#[derive(Debug, Deserialize)]
#[allow(unused)]
pub struct Redis {pub url: String,
}
// 配置
#[derive(Debug, Deserialize)]
#[allow(unused)]
pub struct Settings {pub debug: bool,pub network: Network,pub database: Database,pub redis: Redis,
}impl Settings {pub fn new() -> Result<Self, ConfigError> {// 从环境变量中获取运行模式development、production、testlet run_mode = env::var("RUN_MODE").unwrap_or_else(|_| "development".into());let s = Config::builder()//从“default”配置文件开始合并,后来合并的配置会替换“default”中的配置.add_source(File::with_name("settings/development"))//添加到当前环境文件中//默认为'development'//注意这个功能是可选的,所以可以使用.required(false)来忽略它.add_source(File::with_name(&format!("settings/{}", run_mode)).required(false))// 添加本地配置文件// 不应该加入git.add_source(File::with_name("settings/local").required(false))// 从环境中添加配置(带有APP前缀)// 如. .' APP_DEBUG=1 ./target/app '将设置' debug '键.add_source(Environment::with_prefix("app"))// 您也可以通过编程方式覆盖配置,database.url将覆盖配置文件中的值// .set_override("database.url", "postgres://")?.build()?;// 访问配置println!("debug: {:?}", s.get_bool("debug"));println!("database: {:?}", s.get::<String>("database.url"));// 反序列化s.try_deserialize()}/// 如果你想封装数据,仅将Settings设置为pub,其他的修改为private,通过自定义函数使用/// ```/// let settings = Settings::new().expect("Failed to load settings");/// let db_url = settings.get_db_url();/// println!("Database URL: {:?}", db_url);/// ```pub fn get_newtwork(&self) -> String {format!("{}:{}", &self.network.host, &self.network.port).parse().unwrap()}
}

使用

// Load the merged settings and dump them with their Debug representation.
let settings = Settings::new().expect("Failed to load settings");
println!("{settings:?}");

结果

Settings { debug: true, network: Network { host: "127.0.0.1", port: 8080 }, database: Database { url: "postgres://postgres:root123456@localhost:5432/postgres" }, redis: Redis { url: "redis://127.0.0.1:6379" } }

在Axum中使用

在路由上使用 Extension:任何实现了 Clone 的结构体都可以通过 Extension 提供给路由使用

    let settings = Settings::new().expect("Failed to load settings");
    println!("{:?}", settings);
    // development.toml overrides default.toml's 0.0.0.0 with 127.0.0.1
    println!("{:?}", settings.network.host);
    println!("{:?}", settings.database.url);
    // Format the configured host and port into the address string to bind.
    let server_url = format!("{}:{}", settings.network.host, settings.network.port);
    // Bind to the address read from the settings files.
    let listener = tokio::net::TcpListener::bind(server_url)
        .await
        .expect("failed to bind the configured server address");

从环境变量中选择开发测试配置

1、设置配置

.env

# 指定当前配置文件
RUN_MODE=development
# 开启调试模式
# 以 APP__ 开头的环境变量会覆盖配置文件中的同名配置,层级之间用 __ 分隔
# 例如 APP__DB__PASSWORD=your_password 会覆盖 [db] 中的 password
RUST_BACKTRACE=1
# 此配置仅仅用于seaormcli
DATABASE_URL=postgres://postgres:root123456@localhost:5432/postgres
LOG_LEVEL=TRACE

设置默认配置文件settings\default.toml

debug = true
[server]
host = "0.0.0.0"
port = 8080

设置开发配置文件settings\development.toml

debug = true
# 指定开发环境配置
profile = "development"
[tracing]
log_level = "debug"
[server]
host = "127.0.0.1"
port = 9090
[db]
username = "postgres"
password = "root123456"
host = "127.0.0.1"
port = 5_432
max_connections = 100
database_name = "postgres"
[redis]
host = "127.0.0.1"
port = 6_379

读取环境变量config\env.rs

use std::str::FromStr;
use config::ConfigError;
use super::profile::Profile;
// 获取环境变量配置,例如本例中prefix为"APP",环境变量前缀分隔符和环境变量分隔符都设置为"__"
// 使用,则将环境变量中的DATABASE_URL作为配置项
pub fn get_env_source(prefix: &str) -> config::Environment {// 创建新的环境变量配置config::Environment::with_prefix(prefix)// 设置环境变量前缀分隔符和环境变量分隔符.prefix_separator("__").separator("__")
}
/// Reads the active [`Profile`] from the `RUN_MODE` environment variable.
///
/// An unset (or non-Unicode) `RUN_MODE` yields [`Profile::Dev`].
///
/// # Errors
/// Returns [`ConfigError::Message`] when `RUN_MODE` is set to an unknown profile name.
pub fn get_profile() -> Result<Profile, config::ConfigError> {
    match std::env::var("RUN_MODE") {
        Ok(mode) => Profile::from_str(&mode).map_err(|e| ConfigError::Message(e.to_string())),
        Err(_) => Ok(Profile::Dev),
    }
}

数据库配置config\database.rs

use serde::Deserialize;
#[derive(Debug, Deserialize, Clone)]
pub struct DatabaseConfig {pub username: String,pub password: String,pub port: u16,pub host: String,pub max_connections: u32,pub database_name: String,
}impl DatabaseConfig {pub fn get_url(&self) -> String {Self::create_url(&self.username, &self.password, &self.host, self.port, &self.database_name)}// 创建数据库连接字符串pub fn create_url(username: &str,password: &str,host: &str,port: u16,database_name: &str) -> String {format!("postgres://{username}:{password}@{host}:{port}/{database_name}")}
}

redis配置config\redis.rs

pub use redis::Client;
use serde::Deserialize;

/// Redis connection settings, read from the `[redis]` table.
#[derive(Debug, Deserialize, Clone)]
pub struct RedisConfig {
    pub port: u16,
    pub host: String,
}

impl RedisConfig {
    /// Returns the connection URL in the form `redis://host:port`.
    pub fn get_url(&self) -> String {
        let RedisConfig { host, port } = self;
        format!("redis://{host}:{port}")
    }
}

服务端配置config\server.rs

use std::net::{AddrParseError, SocketAddr};
use serde::Deserialize;
#[derive(Debug, Deserialize, Clone)]
pub struct ServerConfig {pub host: String,pub port: u16,
}impl ServerConfig {// 获取地址pub fn get_addr(&self) -> String {format!("{}:{}", self.host, self.port)}// 获取http地址pub fn get_http_addr(&self) -> String {format!("http://{}:{}", self.host, self.port)}// 获取socket地址pub fn get_socket_addr(&self) -> Result<SocketAddr, AddrParseError> {self.get_addr().parse()}
}#[cfg(test)]
pub mod tests {use super::*;#[test]pub fn app_config_http_addr_test() {let config = ServerConfig {host: "127.0.0.1".to_string(),port: 8080,};assert_eq!(config.get_http_addr(), "http://127.0.0.1:8080");}#[test]pub fn app_config_socket_addr_test() {let config = ServerConfig {host: "127.0.0.1".to_string(),port: 8080,};assert_eq!(config.get_socket_addr().unwrap().to_string(), "127.0.0.1:8080");}
}

日志配置config\tracing.rs

use serde::Deserialize;
/// Logging settings, read from the `[tracing]` table.
#[derive(Debug, Deserialize, Clone)]
pub struct TracingConfig {
    log_level: String,
}

impl TracingConfig {
    /// Returns the configured log level / filter directive (e.g. `"debug"`).
    pub fn get_log_level(&self) -> String {
        self.log_level.clone()
    }
}

#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    pub fn app_config_log_level_test() {
        let config = TracingConfig { log_level: "debug".to_string() };
        assert_eq!(config.get_log_level(), "debug");
    }
}

读取配置类型config\profile.rs

use serde::Deserialize;#[derive(Debug,strum::Display,strum::EnumString,Deserialize,PartialEq,Eq,PartialOrd,Ord,Clone,Copy,
)]
pub enum Profile {#[serde(rename = "test")]// 序列化、反序列化重命名#[strum(serialize = "test")]// 枚举序列化、反序列化重命名Test,#[serde(rename = "development")]#[strum(serialize = "development")]Dev,#[serde(rename = "production")]#[strum(serialize = "production")]Prod,
}

构建配置config\mod.rs


pub mod database;
pub mod env;
pub mod redis;
pub mod server;
pub mod profile;
pub mod tracing;
use std::str::FromStr;
use database::DatabaseConfig;
use profile::Profile;
use redis::RedisConfig;
use server::ServerConfig;
use ::tracing::info;
use config::{ConfigError, Environment};
use serde::Deserialize;
use tracing::TracingConfig;
use utils::dir::get_project_root;
use crate::utils;#[derive(Debug, Deserialize, Clone)]
pub struct AppConfig {pub profile: Profile,pub tracing: TracingConfig,pub server: ServerConfig,pub db: DatabaseConfig,pub redis: RedisConfig,
}impl AppConfig {pub fn read(env_src: Environment) -> Result<Self, config::ConfigError> {// 获取配置文件目录let config_dir = get_settings_dir()?;info!("config_dir: {:#?}", config_dir);// 获取配置文件环境let run_mode = std::env::var("RUN_MODE").map(|env| Profile::from_str(&env).map_err(|e| ConfigError::Message(e.to_string()))).unwrap_or_else(|_e| Ok(Profile::Dev))?;// 当前配置文件名let profile_filename = format!("{run_mode}.toml");// 获取配置let config = config::Config::builder()// 添加默认配置.add_source(config::File::from(config_dir.join("default.toml")))// 添加自定义前缀配置.add_source(config::File::from(config_dir.join(profile_filename)))// 添加环境变量.add_source(env_src).build()?;info!("Successfully read config profile: {run_mode}.");// 反序列化config.try_deserialize()}
}
/// Returns the `settings` directory located under the project root.
///
/// # Errors
/// Wraps a failure to locate the project root in [`ConfigError::Message`].
pub fn get_settings_dir() -> Result<std::path::PathBuf, ConfigError> {
    let project_root = get_project_root().map_err(|e| ConfigError::Message(e.to_string()))?;
    Ok(project_root.join("settings"))
}
// Optional helper: the static-asset directory under the project root.
// pub fn get_static_dir() -> Result<std::path::PathBuf, ConfigError> {
//     Ok(get_project_root()
//         .map_err(|e| ConfigError::Message(e.to_string()))?
//         .join("static"))
// }

#[cfg(test)]
mod tests {
    use crate::config::profile::Profile;
    use super::env::get_env_source;
    pub use super::*;

    #[test]
    pub fn test_profile_to_string() {
        // "development" must round-trip between the Dev variant and its string form.
        let profile: Profile = Profile::try_from("development").unwrap();
        println!("profile: {:#?}", profile);
        assert_eq!(profile, Profile::Dev);
        assert_eq!(profile.to_string(), "development");
    }

    #[test]
    pub fn test_read_app_config_prefix() {
        // Reads the real settings files plus APP__* environment variables.
        let config = AppConfig::read(get_env_source("APP")).unwrap();
        println!("config: {:#?}", config);
    }
}

2、设置客户端

InfraError是自定义的错误,你可以改成标准库的错误

redis客户端client\redis.rs

use redis::{Client, RedisError};
use std::time::Duration;
use tracing::log::info;
use test_context::AsyncTestContext;
use crate::{config::{AppConfig,redis::RedisConfig}, constant::CONFIG};
use super::builder::ClientBuilder;
// 类型别名
pub type RedisClient = redis::Client;pub trait RedisClientExt: ClientBuilder {fn ping(&self) -> impl std::future::Future<Output = Result<Option<String>, RedisError>>;fn set(&self,key: &str,value: &str,expire: Duration,) -> impl std::future::Future<Output = Result<(), RedisError>>;fn exist(&self, key: &str) -> impl std::future::Future<Output = Result<bool, RedisError>>;fn get(&self,key: &str,) -> impl std::future::Future<Output = Result<Option<String>, RedisError>>;fn del(&self, key: &str) -> impl std::future::Future<Output = Result<bool, RedisError>>;fn ttl(&self, key: &str) -> impl std::future::Future<Output = Result<i64, RedisError>>;
}impl ClientBuilder for RedisClient {fn build_from_config(config: &AppConfig) -> Result<RedisClient,InfraError> {Ok(redis::Client::open(config.redis.get_url())?)}
}pub struct RedisTestContext {pub config: RedisConfig,pub redis: RedisClient,
}impl AsyncTestContext for RedisTestContext {async fn setup() -> Self {info!("setup redis config for the test");// let database_name = util::string::generate_random_string_with_prefix("test_db");let redis = RedisClient::build_from_config(&CONFIG).unwrap();Self {config: CONFIG.redis.clone(),redis,}}
}impl RedisClientExt for Client {async fn ping(&self) -> Result<Option<String>, RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let value: Option<String> = redis::cmd("PING").query_async(&mut conn).await?;info!("ping redis server");Ok(value)}async fn set(&self, key: &str, value: &str, expire: Duration) -> Result<(), RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let msg: String = redis::cmd("SET").arg(&[key, value]).query_async(&mut conn).await?;info!("set key redis: {msg}");let msg: i32 = redis::cmd("EXPIRE").arg(&[key, &expire.as_secs().to_string()]).query_async(&mut conn).await?;info!("set expire time redis: {msg}");Ok(())}async fn exist(&self, key: &str) -> Result<bool, RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let value: bool = redis::cmd("EXISTS").arg(key).query_async(&mut conn).await?;info!("check key exists: {key}");Ok(value)}async fn get(&self, key: &str) -> Result<Option<String>, RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let value: Option<String> = redis::cmd("GET").arg(key).query_async(&mut conn).await?;info!("get value: {key}");Ok(value)}async fn del(&self, key: &str) -> Result<bool, RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let value: i32 = redis::cmd("DEL").arg(key).query_async(&mut conn).await?;info!("delete value: {key}");Ok(value == 1)}async fn ttl(&self, key: &str) -> Result<i64, RedisError> {let mut conn = self.get_multiplexed_async_connection().await?;let value: i64 = redis::cmd("TTL").arg(key).query_async(&mut conn).await?;info!("get TTL value: {key}");Ok(value)}
}#[cfg(test)]
mod tests {use crate::constant::REDIS;use super::*;use fake::{Fake, Faker};use uuid::Uuid;#[tokio::test]async fn test_ping_redis_server() {let resp = REDIS.ping().await.unwrap();let pong = "PONG";assert!(matches!(resp, Some(p) if p == pong));}#[tokio::test]async fn test_set_key_redis() {let key: String = Faker.fake();let value = Uuid::new_v4().to_string();REDIS.set(&key, &value, Duration::from_secs(5)).await.unwrap();let resp = REDIS.get(&key).await.unwrap();assert!(matches!(resp, Some(v) if v == value));let resp = REDIS.ttl(&key).await.unwrap();assert!(resp > 0);}#[tokio::test]async fn test_exist_key_redis() {let key: String = Faker.fake();let value = Uuid::new_v4().to_string();REDIS.set(&key, &value, Duration::from_secs(4)).await.unwrap();let resp = REDIS.get(&key).await.unwrap();assert!(matches!(resp, Some(v) if v == value));let resp = REDIS.exist(&key).await.unwrap();assert!(resp);let key: String = Faker.fake();let resp = REDIS.exist(&key).await.unwrap();assert!(!resp);}#[tokio::test]async fn test_del_key_redis() {let key: String = Faker.fake();let value = Uuid::new_v4().to_string();REDIS.set(&key, &value, Duration::from_secs(4)).await.unwrap();let resp = REDIS.get(&key).await.unwrap();assert!(matches!(resp, Some(v) if v == value));let resp = REDIS.exist(&key).await.unwrap();assert!(resp);REDIS.del(&key).await.unwrap();let resp = REDIS.exist(&key).await.unwrap();assert!(!resp);}#[tokio::test]async fn test_key_ttl_redis() {let key: String = Faker.fake();let ttl = 4;let value = Uuid::new_v4().to_string();REDIS.set(&key, &value, Duration::from_secs(ttl)).await.unwrap();let resp = REDIS.get(&key).await.unwrap();assert!(matches!(resp, Some(v) if v == value));let resp = REDIS.ttl(&key).await.unwrap();assert!(resp <= ttl as i64 && resp > 0);REDIS.del(&key).await.unwrap();let resp = REDIS.ttl(&key).await.unwrap();assert!(resp < 0);}
}

数据库客户端client\database.rs

use std::time::Duration;
use common::error::InfraError;
use sea_orm::{ConnectOptions, Database, DatabaseConnection};
use tracing::info;
use crate::config::AppConfig;// 类型别名
/// Alias for the SeaORM connection type used as this crate's database client.
pub type DatabaseClient = DatabaseConnection;

/// Async construction of a database client from the application configuration.
pub trait DatabaseClientExt: Sized {
    /// Opens a connection pool using the `[db]` section of `config`.
    fn build_from_config(
        config: &AppConfig,
    ) -> impl std::future::Future<Output = Result<Self, InfraError>>;
}

impl DatabaseClientExt for DatabaseClient {
    async fn build_from_config(config: &AppConfig) -> Result<Self, InfraError> {
        // Honour the configured pool size instead of a hard-coded value.
        let max_connections = config.db.max_connections;
        let mut opt = ConnectOptions::new(config.db.get_url());
        opt.max_connections(max_connections)
            // Keep the warm-pool floor from exceeding a smaller configured maximum.
            .min_connections(max_connections.min(5))
            .connect_timeout(Duration::from_secs(8))
            .acquire_timeout(Duration::from_secs(8))
            // NOTE(review): an 8s idle timeout and max lifetime recycle pooled
            // connections very aggressively - confirm this is intended.
            .idle_timeout(Duration::from_secs(8))
            .max_lifetime(Duration::from_secs(8))
            .sqlx_logging(false)
            .sqlx_logging_level(log::LevelFilter::Info);
        let db = Database::connect(opt).await?;
        info!("Database connected");
        Ok(db)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constant::CONFIG;

    /// Connects with the shared test configuration and pings the server.
    #[tokio::test]
    async fn test_ping_database() {
        let client = DatabaseClient::build_from_config(&CONFIG)
            .await
            .unwrap();
        client.ping().await.expect("Database ping failed.")
    }
}

日志客户端logger\log.rs

use std::env;
use tracing::dispatcher::set_global_default;
use tracing_appender::rolling::daily;
use tracing_subscriber::{ fmt::{ self, time::UtcTime }, layer::SubscriberExt, EnvFilter, Registry };
pub struct LogGuard(pub std::sync::Arc<tracing_appender::non_blocking::WorkerGuard>);
pub async fn setup_logs(log_level: Option<String>) -> LogGuard {// 读取日志级别let log_level = log_level.unwrap_or_else(|| "debug".to_string());// 设置日志级别过滤器let env_filter = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new(log_level)).unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into());// 创建每日滚动的日志文件写入器let file_appender = daily("logs", "app.log");let (non_blocking_appender, guard) = tracing_appender::non_blocking(file_appender);// 初始化日志订阅者,同时输出到控制台和文件let subscriber = Registry::default().with(env_filter)// 控制台输出层.with(fmt::layer().with_ansi(true).with_target(false))// 文件输出层.with(fmt::Layer::new().with_writer(non_blocking_appender).with_timer(UtcTime::rfc_3339()));// 设置为全局日志订阅者set_global_default(subscriber.into()).expect("setting default subscriber failed");LogGuard(std::sync::Arc::new(guard))
}

构建客户端client\builder.rs

use crate::config::AppConfig;
// 传输配置文件到客户端
pub trait ClientBuilder: Sized {fn build_from_config(config: &AppConfig) -> Result<Self,InfraError>;
}

3、使用AppState存储配置和数据库链接

AppStatestate\mod.rs

use std::sync::Arc;
use anyhow::Error;
use infrastructure::{client::{builder::ClientBuilder, database::{DatabaseClient, DatabaseClientExt}, redis::RedisClient}, config::AppConfig};
// 使用Arc来共享数据,避免数据的复制和所有权的转移
#[derive(Clone)]
pub struct AppState {pub config: Arc<AppConfig>,pub redis: Arc<RedisClient>,pub db: Arc<DatabaseClient>,
}impl AppState {pub async fn new(config: AppConfig) -> Result<Self,Error> {let redis = Arc::new(RedisClient::build_from_config(&config)?);let db = Arc::new(DatabaseClient::build_from_config(&config).await?);Ok(Self {config: Arc::new(config),db,redis,})}
}

使用AppState

use application::dto::response_dto::{ EmptyData, Res };
use axum::{ http::{ HeaderValue, Method }, response::IntoResponse, Router };
use infrastructure::{config::{env::get_env_source, AppConfig}, logger::log};
use tokio::signal;
use tower_http::cors::CorsLayer;
use tracing::info;
use crate::state::AppState;pub async fn start() -> anyhow::Result<()> {// 加载.env 环境配置文件,成功返回包含的值,失败返回Nonedotenvy::dotenv().ok();// 加载AppStatelet config = AppConfig::read(get_env_source("APP"))?;let state = AppState::new(config.clone()).await?;info!("The initialization of Settings was successful");// 初始化日志let guard = log::setup_logs(Some(config.tracing.get_log_level())).await;info!("The initialization of Tracing was successful");// 路由以及后备处理let app = setup_routes().await.fallback(handler_404).with_state(state);// 端口绑定let listener = tokio::net::TcpListener::bind(config.server.get_socket_addr()?).await.unwrap();// 调用 `tracing` 包的 `info!`,放在启动服务之前,因为会被moveinfo!("🚀 listening on {}", &listener.local_addr().unwrap());// 启动服务axum::serve(listener, app.into_make_service()).with_graceful_shutdown(shutdown_signal()).await.unwrap();// 在程序结束前释放资源drop(guard);Ok(())
}/// 嵌套路由
pub async fn setup_routes() -> Router<AppState> {Router::new()// .nest("/users", setup_user_routes().await)//请注意,对于某些请求类型,例如发布content-style:app/json//需要添加“.allow_heads([http::header::CONTENT_GROUP])”.layer(CorsLayer::new().allow_origin("http://localhost:3000".parse::<HeaderValue>().unwrap()).allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]))
}/// 404处理
async fn handler_404() -> impl IntoResponse {Res::<EmptyData>::with_not_found()
}
/// Resolves when the process receives Ctrl+C or, on Unix, SIGTERM, so the server can
/// shut down gracefully.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    // Non-Unix platforms only react to Ctrl+C.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {},
        () = terminate => {},
    }

    println!("signal received, starting graceful shutdown");
}

项目地址:https://github.com/VCCICCV/MGR/tree/main/auth

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com