欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 名人名企 > 径向基函数神经网络RBFNN案例实操

径向基函数神经网络RBFNN案例实操

2024/11/30 18:32:41 来源:https://blog.csdn.net/wwz1751879/article/details/142213939  浏览:    关键词:径向基函数神经网络RBFNN案例实操

简介

(来自ChatGPT的介绍,如有更正建议请指出)
径向基函数神经网络(Radial Basis Function Neural Network, RBFNN)是一种特殊的前馈神经网络,其结构和特点与其他常见的神经网络有所不同,主要表现在以下几个方面:

  1. 网络结构
  • 三层结构:RBF神经网络通常由三层组成:输入层、隐藏层和输出层。与常规的多层感知器(MLP)相比,其隐藏层的激活函数采用径向基函数(通常是高斯函数)。
  • 隐层节点:每个隐藏层节点表示一个中心点,输入到该节点的距离通过径向基函数进行转换,因此隐层节点的作用类似于一种局部响应单位。
  1. 径向基函数
  • 局部特性:径向基函数的输出是基于输入数据与某个中心的距离,输出在离中心较近的地方较大,离中心较远时趋近于零。常见的径向基函数包括高斯函数、多项式、逆多二次函数等。
  • 非线性映射:由于径向基函数可以对输入进行非线性映射,RBFNN具备了处理复杂非线性问题的能力。
  1. 训练过程
  • 两阶段训练:RBFNN的训练分为两个阶段。首先是通过无监督学习(如K均值聚类)确定隐藏层节点的中心,然后在第二阶段使用线性优化算法(如最小二乘法)来训练输出层的权重。这种方式相比传统的反向传播算法收敛速度更快。
  • 少量参数优化:因为RBF神经网络的输出层通常是线性权重,因此需要优化的参数数量较少,训练速度快。
  1. 应用场景
  • 函数逼近与插值:RBFNN特别适合用于函数逼近和插值问题,因为它对输入空间进行局部化响应,能够较好地拟合复杂的函数。
  • 分类和回归:RBFNN可以用来解决分类和回归问题,尤其适合那些具有明显聚类特性的任务。
  1. 特点总结
  • 良好的逼近能力:RBFNN具有良好的逼近非线性函数的能力,并且在理论上可以逼近任意连续函数。
  • 快速训练:由于采用两阶段训练方法,RBF神经网络的训练速度较快。
  • 鲁棒性差:RBFNN对噪声数据的敏感度较高,容易出现过拟合现象。
  • 需要确定中心和宽度参数:选择合适的中心和宽度参数对网络性能有很大影响,选择不当会导致模型性能较差。

RBF神经网络的这些特点使其在一定的应用场景中表现突出,尤其是在需要快速训练和局部响应的情况下效果尤为显著。

案例实操

案例概述

平面上有一些点,将这些点映射到马鞍面的对应点上。
输入为平面点空间坐标(x,y,z),输出为马鞍面点空间坐标(x,y,z)
训练,预测一些点的对应坐标

函数主体

# 使用 matplotlib 的交互模式
%matplotlib widgetimport torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans# RBF Layer
class RBFLayer(nn.Module):def __init__(self, in_features, out_features, centers, gamma=1.0):super(RBFLayer, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.centers = nn.Parameter(torch.tensor(centers, dtype=torch.float32), requires_grad=False)self.gamma = gammadef forward(self, x):# 计算输入到每个中心的欧式距离x = x.unsqueeze(1).repeat(1, self.out_features, 1)centers = self.centers.unsqueeze(0).repeat(x.size(0), 1, 1)distance = torch.norm(x - centers, dim=2)# 应用 RBF 函数 (高斯函数)return torch.exp(-self.gamma * distance ** 2)# RBFNN Model
class RBFNN(nn.Module):def __init__(self, input_dim, output_dim, num_centers, gamma=1.0):super(RBFNN, self).__init__()self.num_centers = num_centersself.gamma = gamma# 使用 KMeans 聚类确定中心kmeans = KMeans(n_clusters=num_centers, random_state=42)centers = kmeans.fit(X_train).cluster_centers_# RBF 层self.rbflayer = RBFLayer(input_dim, num_centers, centers, gamma)# 输出层,线性回归self.linear = nn.Linear(num_centers, output_dim)def forward(self, x):# 通过 RBF 层和线性层rbf_out = self.rbflayer(x)return self.linear(rbf_out)

训练

# 生成示例数据,平面包络体 -> 马鞍面包络体
n_samples = 500
X_train = np.random.uniform(-1, 1, (n_samples, 2))  # 平面 (x, y)
# 计算马鞍面
y_train_saddle = X_train[:, 0]**2 - X_train[:, 1]**2# 将 y_train_saddle 重新形状为列向量
y_train_saddle = y_train_saddle.reshape(-1, 1)# 水平堆叠 X_train 的列和 y_train_saddle
y_train = np.hstack([X_train, y_train_saddle])
X_train = np.hstack([X_train, np.zeros((n_samples, 1))])  # 平面 z = 0# 转换为 PyTorch tensor
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.FloatTensor(y_train)# 初始化模型
input_dim = 3  # 输入为平面上的三维坐标
output_dim = 3  # 输出为曲面上的三维坐标
num_centers = 50  # RBF 隐藏层节点
gamma = 10model = RBFNN(input_dim=input_dim, output_dim=output_dim, num_centers=num_centers, gamma=gamma)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 1000
for epoch in range(epochs):model.train()# 前向传播outputs = model(X_train_tensor)loss = criterion(outputs, y_train_tensor)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

结果显示:

Epoch [100/1000], Loss: 0.1242
Epoch [200/1000], Loss: 0.0590
Epoch [300/1000], Loss: 0.0344
Epoch [400/1000], Loss: 0.0234
Epoch [500/1000], Loss: 0.0178
Epoch [600/1000], Loss: 0.0146
Epoch [700/1000], Loss: 0.0125
Epoch [800/1000], Loss: 0.0110
Epoch [900/1000], Loss: 0.0099
Epoch [1000/1000], Loss: 0.0089

预测

# 预测新点
num_points = 5
x_new = np.linspace(-1, 1, num_points)
y_new = np.linspace(-1, 1, num_points)# 生成网格数据并展开为二维点
X_new, Y_new = np.meshgrid(x_new, y_new)
X_new = X_new.flatten()
Y_new = Y_new.flatten()# 创建新点数组(z 坐标初始化为 0)
new_points = np.column_stack([X_new, Y_new, np.zeros_like(X_new)])# 转换为 PyTorch tensor
new_points_tensor = torch.FloatTensor(new_points)# 使用模型进行预测
model.eval()
with torch.no_grad():predicted_points = model(new_points_tensor).detach().numpy()

误差计算

# 假设预测点和实际点的坐标已经存在
# new_points 是输入点 (x, y, z)
# predicted_points 是模型的预测点 (x, y, z)# 生成实际的 z 值
actual_z = new_points[:, 0]**2 - new_points[:, 1]**2# 计算预测的 z 值
predicted_z = predicted_points[:, 2]# 计算误差
errors = np.abs(predicted_z - actual_z)# 打印误差
print("Errors (absolute difference):", errors)# 计算均方误差
mse = np.mean(errors**2)
print("Mean Squared Error (MSE):", mse)# 计算均方根误差
rmse = np.sqrt(mse)
print("Root Mean Squared Error (RMSE):", rmse)

显示:

Errors (absolute difference): [0.00164947 0.28711906 0.33240026 0.21893883 0.06195563 0.4516308
0.01005104 0.04899709 0.00997831 0.30071303 0.38031268 0.04964676
0.00083099 0.0416033 0.33901107 0.41264805 0.03607443 0.03517236
0.02895199 0.25411072 0.01035234 0.20135468 0.30079675 0.23898464
0.13712199]
Mean Squared Error (MSE): 0.050029532713638136
Root Mean Squared Error (RMSE): 0.22367282515683065

可视化

# 绘制马鞍面
x = np.linspace(-1, 1, 100)
y = np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x, y)
Z = X**2 - Y**2# 使用交互式 matplotlib widget 绘制马鞍面
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.6)# 添加预测点(如果有)
# ax.scatter(X_train[:, 0], X_train[:, 1], X_train[:, 2], c='g', marker='o', label='Surface Points')
# ax.scatter(y_train[:, 0], y_train[:, 1], y_train[:, 2], c='b', marker='o', label='Saddle Surface Points')
ax.scatter(predicted_points[:, 0], predicted_points[:, 1], predicted_points[:, 2], c='r', marker='x', s=100, label='Predicted Point')ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')plt.legend()
plt.show()

结果:
在这里插入图片描述

打赏博主

制作不易,如果有能帮到你,可以打赏博主一瓶可乐
在这里插入图片描述

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com