欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 汽车 > 时评 > 机器学习day8

机器学习day8

2025/2/7 5:01:00 来源:https://blog.csdn.net/h1413517383/article/details/145432221  浏览:    关键词:机器学习day8

自定义数据集 ,使用朴素贝叶斯对其进行分类

代码

import numpy as np
import matplotlib.pyplot as pltclass1_points = np.array([[2.1, 2.2], [2.4, 2.5], [2.2, 2.0], [2.0, 2.1], [2.3, 2.3], [2.6, 2.4], [2.5, 2.1]])
class2_points = np.array([[4.0, 3.5], [4.2, 3.9], [4.1, 3.8], [3.7, 3.4], [4.4, 3.6], [4.5, 3.7], [4.3, 3.9]])X = np.concatenate((class1_points, class2_points), axis=0)
Y = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)
print(Y)prior_prob = [np.sum(Y == 0) / len(Y), np.sum(Y == 1) / len(Y)]class_μ = [np.mean(X[Y == 0], axis=0), np.mean(X[Y == 1], axis=0)]
class_cov = [np.cov(X[Y == 0], rowvar=False), np.cov(X[Y == 1], rowvar=False)]def pdf(x, mean, cov):n = len(mean)coff = 1 / (2 * np.pi) ** (n / 2) * np.sqrt(np.linalg.det(cov))exponent = np.exp(-(1 / 2) * np.dot(np.dot((x - mean).T, np.linalg.inv(cov)), (x - mean)))return coff * exponentxx, yy = np.meshgrid(np.arange(0, 5, 0.05), np.arange(0, 5, 0.05))
grid_points = np.c_[xx.ravel(), yy.ravel()]grid_label = []
for point in grid_points:poster_prob = []for i in range(2):likelihood = pdf(point, class_μ[i], class_cov[i])poster_prob.append(prior_prob[i] * likelihood)pre_class = np.argmax(poster_prob)grid_label.append(pre_class)plt.scatter(class1_points[:, 0], class1_points[:, 1], c="blue", label="class 1")
plt.scatter(class2_points[:, 0], class2_points[:, 1], c="red", label="class 2")
plt.legend()grid_label = np.array(grid_label)
pre_grid_label = grid_label.reshape(xx.shape)
contour = plt.contour(xx, yy, pre_grid_label, level=0.5, color='green')plt.show()

效果

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com