Matplotlib绘制散点图趋势线:全面指南与实例
参考:Drawing Scatter Trend Lines Using Matplotlib
Matplotlib是Python中最流行的数据可视化库之一,它提供了强大的工具来创建各种类型的图表,包括散点图和趋势线。本文将深入探讨如何使用Matplotlib绘制散点图并添加趋势线,这是数据分析和科学研究中常用的技术。我们将从基础开始,逐步深入,涵盖多种方法和技巧,以帮助您掌握这一重要的数据可视化技能。
1. 散点图基础
散点图是展示两个变量之间关系的最基本和最有效的方法之一。在Matplotlib中,我们可以使用plt.scatter()
函数来创建散点图。
1.1 创建简单的散点图
让我们从一个简单的散点图开始:
import matplotlib.pyplot as plt
import numpy as np
# 生成示例数据
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + np.random.randn(50)
# 创建散点图
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6)
plt.title('Simple Scatter Plot - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.show()
Output:
在这个例子中,我们首先导入必要的库,然后生成一些示例数据。np.linspace()
用于创建一个均匀分布的x值数组,而y值是通过一个简单的线性关系加上一些随机噪声生成的。plt.scatter()
函数用于绘制散点图,我们设置了点的颜色和透明度。最后,我们添加了标题和轴标签,并显示图形。
1.2 自定义散点图
Matplotlib提供了多种方式来自定义散点图的外观:
import matplotlib.pyplot as plt
import numpy as np
# 生成示例数据
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + np.random.randn(50)
# 创建自定义散点图
plt.figure(figsize=(10, 6))
plt.scatter(x, y, c=y, cmap='viridis', s=50, alpha=0.7, edgecolors='black')
plt.colorbar(label='Y values')
plt.title('Customized Scatter Plot - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()
Output:
在这个例子中,我们使用了更多的自定义选项:
– c=y
:根据y值来设置点的颜色
– cmap='viridis'
:使用viridis颜色映射
– s=50
:设置点的大小
– edgecolors='black'
:设置点的边缘颜色
– plt.colorbar()
:添加颜色条
– plt.grid()
:添加网格线
这些选项让我们的散点图更加信息丰富和视觉吸引力。
2. 添加趋势线
趋势线可以帮助我们更好地理解数据的整体趋势。在Matplotlib中,我们可以使用多种方法来添加趋势线。
2.1 线性趋势线
最简单的趋势线是线性趋势线,它假设数据遵循线性关系。
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + np.random.randn(50)
# 计算线性回归
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
line = slope * x + intercept
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, line, color='red', label=f'Trend line (R²={r_value**2:.2f})')
plt.title('Scatter Plot with Linear Trend Line - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
在这个例子中,我们使用scipy.stats.linregress()
函数来计算线性回归。这个函数返回斜率、截距和其他统计信息。我们使用这些信息来绘制趋势线,并在图例中显示R²值。
2.2 多项式趋势线
对于非线性关系,我们可以使用多项式趋势线:
import matplotlib.pyplot as plt
import numpy as np
# 生成示例数据
x = np.linspace(0, 10, 100)
y = 2 * x**2 - 5 * x + 3 + np.random.randn(100) * 10
# 计算多项式趋势线
z = np.polyfit(x, y, 2)
p = np.poly1d(z)
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, p(x), color='red', label='Polynomial trend line')
plt.title('Scatter Plot with Polynomial Trend Line - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
在这个例子中,我们使用np.polyfit()
函数来计算多项式系数,然后使用np.poly1d()
创建一个多项式函数。这允许我们绘制一条更适合非线性数据的趋势线。
2.3 局部加权回归散点平滑(LOWESS)
对于更复杂的数据模式,我们可以使用LOWESS(Locally Weighted Scatterplot Smoothing)方法:
import matplotlib.pyplot as plt
import numpy as np
from statsmodels.nonparametric.smoothers_lowess import lowess
# 生成示例数据
x = np.linspace(0, 10, 100)
y = np.sin(x) + np.random.randn(100) * 0.2
# 计算LOWESS趋势线
lowess_result = lowess(y, x, frac=0.6)
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(lowess_result[:, 0], lowess_result[:, 1], color='red', label='LOWESS trend line')
plt.title('Scatter Plot with LOWESS Trend Line - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
LOWESS方法非常适合捕捉数据中的局部趋势,特别是当数据关系不遵循简单的数学函数时。
3. 多组数据的趋势线
在实际应用中,我们经常需要比较多组数据的趋势。Matplotlib允许我们在同一图表上绘制多组数据和它们各自的趋势线。
3.1 多组线性趋势线
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
x = np.linspace(0, 10, 50)
y1 = 2 * x + 1 + np.random.randn(50)
y2 = 1.5 * x + 2 + np.random.randn(50)
# 计算线性回归
slope1, intercept1, r_value1, _, _ = stats.linregress(x, y1)
slope2, intercept2, r_value2, _, _ = stats.linregress(x, y2)
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
plt.scatter(x, y1, color='blue', alpha=0.6, label='Data 1')
plt.scatter(x, y2, color='green', alpha=0.6, label='Data 2')
plt.plot(x, slope1 * x + intercept1, color='red', label=f'Trend 1 (R²={r_value1**2:.2f})')
plt.plot(x, slope2 * x + intercept2, color='orange', label=f'Trend 2 (R²={r_value2**2:.2f})')
plt.title('Multiple Scatter Plots with Trend Lines - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
这个例子展示了如何在同一图表上比较两组数据的线性趋势。我们使用不同的颜色来区分数据集和它们的趋势线,并在图例中显示各自的R²值。
3.2 不同类型的趋势线比较
有时,我们可能需要比较不同类型的趋势线:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from statsmodels.nonparametric.smoothers_lowess import lowess
# 生成示例数据
x = np.linspace(0, 10, 100)
y = 2 * x**2 - 5 * x + 3 + np.random.randn(100) * 10
# 计算线性趋势线
slope, intercept, r_value, _, _ = stats.linregress(x, y)
linear_trend = slope * x + intercept
# 计算多项式趋势线
z = np.polyfit(x, y, 2)
poly_trend = np.poly1d(z)(x)
# 计算LOWESS趋势线
lowess_result = lowess(y, x, frac=0.6)
# 绘制散点图和趋势线
plt.figure(figsize=(12, 7))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, linear_trend, color='red', label=f'Linear (R²={r_value**2:.2f})')
plt.plot(x, poly_trend, color='green', label='Polynomial')
plt.plot(lowess_result[:, 0], lowess_result[:, 1], color='orange', label='LOWESS')
plt.title('Comparison of Different Trend Lines - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
这个例子展示了如何在同一图表上比较线性、多项式和LOWESS趋势线。这种比较可以帮助我们选择最适合数据的趋势线类型。
4. 高级技巧
4.1 置信区间
添加置信区间可以帮助我们了解趋势线的不确定性:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
x = np.linspace(0, 10, 50)
y = 2 * x + 1 + np.random.randn(50) * 2
# 计算线性回归
slope, intercept, r_value, _, std_err = stats.linregress(x, y)
# 计算预测值和置信区间
y_pred = slope * x + intercept
pi = 1.96 * std_err * np.sqrt(1 + 1/len(x) + (x - np.mean(x))**2 / np.sum((x - np.mean(x))**2))
# 绘制散点图、趋势线和置信区间
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, y_pred, color='red', label=f'Trend line (R²={r_value**2:.2f})')
plt.fill_between(x, y_pred - pi, y_pred + pi, color='gray', alpha=0.2, label='95% CI')
plt.title('Scatter Plot with Trend Line and Confidence Interval - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
这个例子展示了如何添加95%置信区间到线性趋势线。置信区间可以帮助我们理解趋势线的可靠性。
4.2 分组数据的趋势线
当我们有分类数据时,可能需要为每个类别绘制单独的趋势线:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
np.random.seed(0)
categories = ['A', 'B', 'C']
colors = ['red', 'green', 'blue']
x = np.random.rand(150)
y = np.random.rand(150)
c = np.random.choice(categories, 150)
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
for category, color in zip(categories, colors):
mask = c == category
x_cat = x[mask]
y_cat = y[mask]
plt.scatter(x_cat, y_cat, c=color, alpha=0.6, label=f'Category {category}')
slope, intercept, r_value, _, _ = stats.linregress(x_cat, y_cat)
line = slope * x_cat + intercept
plt.plot(x_cat, line, c=color, linestyle='--')
plt.title('Scatter Plot with Trend Lines by Category - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
这个例子展示了如何为不同类别的数据绘制单独的趋势线。这种方法在比较不同组或类别的趋势时非常有用。
4.3 动态趋势线
在某些情况下,我们可能希望创建一个交互式的图表,允许用户动态调整趋势线。虽然Matplotlib本身不支持交互功能,但我们可以结合使用ipywidgets来实现这一点:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, interactive, fixed
import ipywidgets as widgets
def plot_trend(degree):
# 生成示例数据
x = np.linspace(0, 10, 100)
y = 2 * x**2 - 5 * x + 3 + np.random.randn(100) * 10
# 计算多项式趋势线
z = np.polyfit(x, y, degree)
p = np.poly1d(z)
# 绘制散点图和趋势线
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, p(x), color='red', label=f'Polynomial trend (degree={degree})')
plt.title('Dynamic Trend Line - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
interact(plot_trend, degree=widgets.IntSlider(min=1, max=10, step=1, value=2))
这个例子创建了一个交互式的图表,用户可以通过滑块调整多项式趋势线的阶数。这种方法特别适用于探索性数据分析,允许用户快速尝试不同的趋势线拟合。
5. 处理异常值和数据清理
在绘制趋势线时,处理异常值是一个重要的步骤,因为异常值可能会显著影响趋势线的形状和方向。
5.1 识别和移除异常值
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据(包含异常值)
np.random.seed(0)
x = np.linspace(0, 10, 100)
y = 2 * x + 1 + np.random.randn(100) * 2
y[80] = 30 # 添加一个异常值
# 计算Z-score
z_scores = np.abs(stats.zscore(y))
# 移除异常值(Z-score > 3)
threshold = 3
x_clean = x[z_scores < threshold]
y_clean = y[z_scores < threshold]
# 计算趋势线
slope, intercept, r_value, _, _ = stats.linregress(x_clean, y_clean)
line = slope * x + intercept
# 绘图
plt.figure(figsize=(12, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Original data')
plt.scatter(x_clean, y_clean, color='green', alpha=0.6, label='Cleaned data')
plt.plot(x, line, color='red', label=f'Trend line (R²={r_value**2:.2f})')
plt.title('Scatter Plot with Outlier Removal - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
这个例子展示了如何使用Z-score方法识别和移除异常值,然后基于清理后的数据绘制趋势线。这种方法可以帮助我们获得更准确的趋势线。
5.2 鲁棒回归
另一种处理异常值的方法是使用鲁棒回归技术,如RANSAC(Random Sample Consensus):
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import RANSACRegressor
# 生成示例数据(包含异常值)
np.random.seed(0)
x = np.linspace(0, 10, 100).reshape(-1, 1)
y = 2 * x.ravel() + 1 + np.random.randn(100) * 2
y[80] = 30 # 添加一个异常值
# 使用RANSAC进行鲁棒回归
ransac = RANSACRegressor()
ransac.fit(x, y)
inlier_mask = ransac.inlier_mask_
outlier_mask = np.logical_not(inlier_mask)
# 绘图
plt.figure(figsize=(12, 6))
plt.scatter(x[inlier_mask], y[inlier_mask], color='blue', alpha=0.6, label='Inliers')
plt.scatter(x[outlier_mask], y[outlier_mask], color='red', alpha=0.6, label='Outliers')
line_x = np.arange(x.min(), x.max())[:, np.newaxis]
line_y = ransac.predict(line_x)
plt.plot(line_x, line_y, color='green', label='RANSAC trend line')
plt.title('Robust Regression with RANSAC - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()
Output:
RANSAC算法能够自动识别异常值并基于内点(inliers)拟合模型,这使得它在存在显著异常值的情况下特别有用。
6. 趋势线的统计评估
在绘制趋势线时,了解其统计显著性和拟合优度是很重要的。
6.1 R-squared和p值
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
np.random.seed(0)
x = np.linspace(0, 10, 100)
y = 2 * x + 1 + np.random.randn(100) * 3
# 计算线性回归
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
line = slope * x + intercept
# 绘图
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.6, label='Data')
plt.plot(x, line, color='red', label=f'Trend line')
plt.title('Scatter Plot with Trend Line and Statistics - how2matplotlib.com')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
# 添加统计信息
stats_text = f'R² = {r_value**2:.3f}\np-value = {p_value:.3e}'
plt.text(0.05, 0.95, stats_text, transform=plt.gca().transAxes,
verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))
plt.show()
Output:
这个例子不仅绘制了趋势线,还在图表上显示了R-squared值和p值。R-squared值表示模型解释数据变异性的程度,而p值表示趋势线的统计显著性。
6.2 残差分析
残差分析是评估趋势线拟合质量的重要工具:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
# 生成示例数据
np.random.seed(0)
x = np.linspace(0, 10, 100)
y = 2 * x + 1 + np.random.randn(100) * 2
# 计算线性回归
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
y_pred = slope * x + intercept
# 计算残差
residuals = y - y_pred
# 创建子图
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
# 绘制散点图和趋势线
ax1.scatter(x, y, color='blue', alpha=0.6, label='Data')
ax1.plot(x, y_pred, color='red', label='Trend line')
ax1.set_title('Scatter Plot with Trend Line - how2matplotlib.com')
ax1.set_xlabel('X-axis')
ax1.set_ylabel('Y-axis')
ax1.legend()
# 绘制残差图
ax2.scatter(x, residuals, color='green', alpha=0.6)
ax2.axhline(y=0, color='red', linestyle='--')
ax2.set_title('Residual Plot')
ax2.set_xlabel('X-axis')
ax2.set_ylabel('Residuals')
plt.tight_layout()
plt.show()
Output:
残差图可以帮助我们识别趋势线是否适当地捕捉了数据的模式。理想情况下,残差应该随机分布在零线周围,没有明显的模式。
7. 高级可视化技巧
7.1 3D散点图和趋势面
对于三维数据,我们可以绘制3D散点图和趋势面:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
# 生成示例数据
np.random.seed(0)
x = np.random.rand(100)
y = np.random.rand(100)
z = 2*x + 3*y + np.random.randn(100)*0.5
# 创建网格
xi = np.linspace(0, 1, 100)
yi = np.linspace(0, 1, 100)
X, Y = np.meshgrid(xi, yi)
# 计算趋势面
A = np.c_[np.ones(len(x)), x, y]
C, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
Z = C[0] + C[1]*X + C[2]*Y
# 创建3D图
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection='3d')
# 绘制散点图
ax.scatter(x, y, z, c='blue', s=50, alpha=0.6, label='Data')
# 绘制趋势面
surf = ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.5)
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Z-axis')
ax.set_title('3D Scatter Plot with Trend Surface - how2matplotlib.com')
plt.colorbar(surf, shrink=0.5, aspect=5)
plt.legend()
plt.show()
Output:
这个例子展示了如何创建3D散点图并添加趋势面。这种可视化方法对于理解三个变量之间的关系特别有用。
7.2 等高线图
对于三维数据,等高线图也是一个有效的可视化工具:
import matplotlib.pyplot as plt
import numpy as np
# 生成示例数据
np.random.seed(0)
x = np.random.rand(100)
y = np.random.rand(100)
z = 2*x + 3*y + np.random.randn(100)*0.5
# 创建网格
xi = np.linspace(0, 1, 100)
yi = np.linspace(0, 1, 100)
X, Y = np.meshgrid(xi, yi)
# 计算趋势面
A = np.c_[np.ones(len(x)), x, y]
C, _, _, _ = np.linalg.lstsq(A, z, rcond=None)
Z = C[0] + C[1]*X + C[2]*Y
# 创建等高线图
plt.figure(figsize=(10, 8))
plt.scatter(x, y, c=z, cmap='viridis', s=50, alpha=0.6)
contour = plt.contour(X, Y, Z, colors='black', alpha=0.7)
plt.clabel(contour, inline=True, fontsize=8)
plt.colorbar(label='Z values')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Contour Plot with Scatter Data - how2matplotlib.com')
plt.show()
Output:
等高线图可以帮助我们理解三维数据在二维平面上的投影,同时保留了高度信息。
结论
本文详细介绍了如何使用Matplotlib绘制散点图和趋势线,涵盖了从基础到高级的多个方面。我们探讨了不同类型的趋势线,包括线性、多项式和LOWESS方法,以及如何处理异常值、进行统计评估和创建高级可视化。这些技术和方法为数据分析和科学研究提供了强大的工具,能够帮助我们更好地理解和展示数据中的模式和关系。
通过掌握这些技巧,您将能够创建更加丰富、信息量更大的数据可视化,从而更有效地传达您的研究结果和见解。记住,选择合适的趋势线类型和可视化方法取决于您的具体数据和研究目标。实践和实验是提高数据可视化技能的关键,所以不要犹豫,开始尝试这些技术吧!