
PyTorch Deep Learning in Practice: Water Quality Classification with a Neural Network

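Type: Trending digest, 2026-06-07
A fully connected neural network built with PyTorch 2.8 for water quality classification. Because the dataset is small, the same set of images is used for both training and testing. The project covers the network definition, model training, performance evaluation, and a visual UI, with the code adapted to updated APIs and the interface polished.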

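Using a neural network to judge water quality is a genuinely practical direction. Here the whole pipeline is broken down step by step: designing the network, training and evaluating the model, and finally building a visual UI, end to end.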

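A few caveats first. First, this project trains and tests on the same set of images because the available data is limited. This is not sound practice in real engineering: the test accuracy then measures how well the model memorized the training images, not how well it generalizes to new ones. Second, the code is based on PyTorch 2.8 and follows the same approach as an earlier project of mine, with some interfaces adapted. Third, the final UI page has been polished considerably and its content expanded. Let's get started.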
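To make the first caveat concrete, here is a minimal sketch of a proper hold-out split, assuming the images can be addressed by index (as they can in an ImageFolder). `split_indices` is a hypothetical helper, not part of the original project; the two index lists it returns could, for example, be wrapped with `torch.utils.data.Subset` to build separate training and test sets.

```python
import random

def split_indices(n_samples, test_ratio=0.2, seed=42):
    """Shuffle sample indices with a fixed seed and split them into train/test lists."""
    indices = list(range(n_samples))
    rng = random.Random(seed)  # private RNG so the split is reproducible
    rng.shuffle(indices)
    n_test = int(n_samples * test_ratio)
    return indices[n_test:], indices[:n_test]  # (train, test), never overlapping

train_idx, test_idx = split_indices(100, test_ratio=0.2)
print(len(train_idx), len(test_idx))  # 80 20
```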

Dataset

Dataset link: https://pan.baidu.com/s/1DSDl5uKF0qaoyVs3f-L7iQ?pwd=wy46 Extraction code: wy46


Overall approach

The project is split into four modules:

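  • Network definition (network.py): designs the main architecture of the neural network.
  • Training (train.py): reads the data and trains the model.
  • Evaluation (evaluation.py): evaluates the performance of the trained model.
  • Visual interface (UI.py): an interactive UI page that lets users apply the model directly.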


Defining the network structure


Implementation

import torch

class WaterQualityNet(torch.nn.Module):
    def __init__(self, input_size=32*32*3, hidden_size=128, num_classes=2):
        super(WaterQualityNet, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, num_classes)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 3072)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)  # output logits (no softmax)
        return x

if __name__ == '__main__':
    model = WaterQualityNet()
    test_input = torch.randn(2, 3, 32, 32)
    output = model(test_input)
    print("--- network.py test result ---")
    print(f"Model output shape (Batch, Classes): {output.shape}")

A quick table summarizing the data flow:

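| Step | Operation / component | Input shape | Output shape | Purpose |
|---|---|---|---|---|
| 1 | Input (x) | (B, 3, 32, 32) | (B, 3, 32, 32) | Raw image |
| 2 | Flatten | (B, 3, 32, 32) | (B, 3072) | Prepare for the fully connected layers |
| 3 | fc1 | (B, 3072) | (B, 128) | Map to the hidden layer |
| 4 | ReLU | (B, 128) | (B, 128) | Non-linear activation |
| 5 | fc2 | (B, 128) | (B, 128) | Refine features |
| 6 | ReLU | (B, 128) | (B, 128) | Second non-linearity |
| 7 | fc3 | (B, 128) | (B, 2) | Map to the class space |
| 8 | Output | (B, 2) | (B, 2) | Logits for the 2 classes |

B is the batch size; 3072 = 3 × 32 × 32.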
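A quick sanity check on the table: each fully connected layer holds in_features × out_features weights plus out_features biases. The sketch below (plain Python, with a hypothetical `linear_params` helper) applies that formula to the default sizes of WaterQualityNet.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameters in a fully connected layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)

# Layer sizes taken from WaterQualityNet's defaults: 3*32*32 -> 128 -> 128 -> 2
layers = [(3 * 32 * 32, 128), (128, 128), (128, 2)]
counts = [linear_params(i, o) for i, o in layers]
print(counts)       # [393344, 16512, 258]
print(sum(counts))  # 410114
```

fc1 accounts for about 96% of the roughly 410k parameters, because it connects all 3072 flattened pixel values to the hidden layer. On the real model, `sum(p.numel() for p in model.parameters())` should report the same total.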

Training

Dataset and DataLoader

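Let's start with a toy house-price dataset as a warm-up: 10 samples, with floor area and building age as features and price as the label.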

import numpy as np

X_features = np.random.rand(10, 2) * 100  # 10 samples x 2 features (area, age)
Y_prices = 2 * X_features[:, 0] + 5 * X_features[:, 1] + 10 + np.random.randn(10) * 5
Y_labels = Y_prices.reshape(-1, 1)  # column vector of shape (10, 1)
print(f"Total number of samples: {len(X_features)}")
print(f"Feature matrix shape: {X_features.shape}")
print(X_features)
print(Y_labels)


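Below, housing_data_warehouse is an instance of the HousingDataset class, holding two attributes (X and Y) and implementing two methods (__len__ and __getitem__).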

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class HousingDataset(Dataset):
    def __init__(self, X_data, Y_data):
        self.X = torch.tensor(X_data, dtype=torch.float32)
        self.Y = torch.tensor(Y_data, dtype=torch.float32)

    def __len__(self):
        print("\n[internal call]: __len__ called, returning the dataset size.")
        return len(self.X)

    def __getitem__(self, idx):
        print(f"\n[internal call]: __getitem__({idx}) called, preparing the sample.")
        return self.X[idx], self.Y[idx]

X_features = np.random.rand(10, 2) * 100
Y_prices = 2 * X_features[:, 0] + 5 * X_features[:, 1] + 10 + np.random.randn(10) * 5
Y_labels = Y_prices.reshape(-1, 1)

housing_data_warehouse = HousingDataset(X_features, Y_labels)
print("\n--- The two attributes of housing_data_warehouse ---")
print(housing_data_warehouse.X)
print(housing_data_warehouse.Y)

print("\n--- What __len__ does ---")
dataset_size = len(housing_data_warehouse)
print(f"len(dataset) called from outside returns: {dataset_size}")
print("Takeaway: __len__ tells DataLoader and other code how many samples the dataset has.")

print("\n--- What __getitem__ does ---")
sample_index = 5
feature_5, label_5 = housing_data_warehouse[sample_index]
print(f"dataset[{sample_index}] called from outside returns:")
print(f"  Features (X): {feature_5.numpy().round(2)}")
print(f"  Label (Y): {label_5.item():.2f}")
print("Takeaway: __getitem__ does the actual loading and returning of one sample.")


batch_size = 4
shuffle = True
num_workers = 0

data_loader_pipe = DataLoader(
    dataset=housing_data_warehouse,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers
)

print(f"Iterating with DataLoader (batch size: {batch_size}):")
total_batches = len(data_loader_pipe)
print(f"{total_batches} mini-batches in total (10 samples / 4 = 2 full batches; the remaining 2 samples form batch 3).")

for i, (batch_X, batch_Y) in enumerate(data_loader_pipe):
    print(f"\n--- Batch {i + 1}/{total_batches} ---")
    print(f"Batch X shape: {batch_X.shape}")
    print(f"Batch Y shape: {batch_Y.shape}")
    print(f"X (first two samples): {batch_X[:2].numpy().round(2)}")
    print(f"Y (first two labels): {batch_Y[:2].numpy().round(2)}")
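The batch arithmetic in the demo generalizes as follows: without `drop_last`, DataLoader yields ceil(N / batch_size) batches and the last one may be short; with `drop_last=True` the incomplete batch is discarded. The helpers below are a plain-Python sketch of that rule (hypothetical names), not DataLoader's own code.

```python
import math

def num_batches(n_samples, batch_size, drop_last=False):
    """How many mini-batches one pass over the data yields."""
    if drop_last:
        return n_samples // batch_size        # the incomplete final batch is discarded
    return math.ceil(n_samples / batch_size)  # the final batch may be smaller

def batch_sizes(n_samples, batch_size, drop_last=False):
    """Size of each mini-batch, in order."""
    full, rest = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)
    return sizes

print(num_batches(10, 4), batch_sizes(10, 4))  # 3 [4, 4, 2]
print(num_batches(10, 4, drop_last=True), batch_sizes(10, 4, drop_last=True))  # 2 [4, 4]
```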


Data loading and preprocessing

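from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

class WaterQualityDataset(Dataset):
    """Thin wrapper around ImageFolder: one sub-folder per class."""
    def __init__(self, root, transform=None):
        self.dataset = ImageFolder(root, transform=transform)

    def __getitem__(self, index):
        return self.dataset[index]  # (image_tensor, label_index)

    def __len__(self):
        return len(self.dataset)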

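ImageFolder is very convenient: it walks the sub-folders of the root directory, uses each folder name as a class label, collects all image paths, and builds a mapping from class names to integer labels (assigned in sorted folder-name order). So self.dataset is already a usable Dataset: its __getitem__ locates the file, reads the image, applies the transform, and returns the image tensor together with its integer label, while __len__ returns the dataset size.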
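The class-discovery step can be sketched in a few lines. This is a simplified imitation of what ImageFolder does (sorted sub-folder names mapped to 0, 1, ...), not torchvision's actual source, and the folder names `clean`/`polluted` are hypothetical. On a real ImageFolder the same mapping is available as `dataset.class_to_idx`, and the order matters later when the UI translates indices into text.

```python
import os
import tempfile

def find_classes(root):
    """Mimic ImageFolder's class discovery: sorted sub-folder names -> integer labels."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx

# Hypothetical folder layout; the real dataset's folder names may differ
with tempfile.TemporaryDirectory() as root:
    for name in ["polluted", "clean"]:
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)

print(classes)       # ['clean', 'polluted']
print(class_to_idx)  # {'clean': 0, 'polluted': 1}
```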


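import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((32, 32)),  # every image becomes 32x32
    transforms.ToTensor(),        # PIL image -> float tensor (C, H, W) in [0, 1]
])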

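Here Resize((32, 32)) scales every image to 32×32, matching the network's input_size of 32*32*3, and ToTensor() converts it to a channels-first tensor with pixel values scaled to the [0, 1] range.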
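Roughly, for an 8-bit RGB image, ToTensor() amounts to moving the channel axis first and dividing by 255. The NumPy sketch below (hypothetical `to_tensor_like`, random pixels standing in for a resized image) shows that transformation; like the pipeline above, it does no mean/std normalization.

```python
import numpy as np

def to_tensor_like(image_hwc_uint8):
    """Rough NumPy analogue of ToTensor for a uint8 RGB image: HWC -> CHW, scaled to [0, 1]."""
    chw = np.transpose(image_hwc_uint8, (2, 0, 1))  # move the channel axis first
    return chw.astype(np.float32) / 255.0           # 0..255 -> 0.0..1.0

# Random pixels stand in for a 32x32 RGB image after Resize
img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
x = to_tensor_like(img)
print(x.shape, x.dtype)  # (3, 32, 32) float32
```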

train_dataset = WaterQualityDataset('D:/dataset', transform=transform)
test_dataset = WaterQualityDataset('D:/dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

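batch_size=1 loads one image per step; in a real project you can increase it as GPU memory allows. shuffle=True reshuffles the training order every epoch, which helps generalization. num_workers can be raised to load data in parallel worker processes when training on a GPU; on Windows, a script that uses num_workers > 0 needs an `if __name__ == '__main__':` guard.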

Defining the model, loss function, and optimizer

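import torch.nn as nn
import torch.optim as optim
from network import WaterQualityNet

model = WaterQualityNet()
criterion = nn.CrossEntropyLoss()                     # takes raw logits and integer labels
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam with learning rate 1e-3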
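nn.CrossEntropyLoss expects raw logits: it applies log-softmax internally and, by default, averages the per-sample losses over the batch. That is why fc3 outputs logits with no Softmax layer. The sketch below computes the loss for a single sample by hand (hypothetical `cross_entropy` helper, made-up logits), using the log-sum-exp trick for numerical stability.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample from raw logits: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]  # equals -log p_target

loss = cross_entropy([2.0, 0.5], target=0)
print(round(loss, 4))  # 0.2014
```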

Selecting the device and training

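device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the parameters to the GPU if one is available

num_epochs = 10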
loss_history = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    loss_history.append(epoch_loss)
    print(f'Training Loss: {epoch_loss}')

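Key points: model.train() switches the model to training mode; optimizer.zero_grad() clears the gradients left over from the previous step (PyTorch accumulates gradients by default); the forward pass computes the outputs; the loss function computes the loss; loss.backward() computes the gradients; and optimizer.step() updates the parameters. At the end of each epoch, the average loss over all batches is recorded.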
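The same five steps can be traced without any framework. The sketch below is a hypothetical, deliberately tiny analogue (fitting y = w·x by per-sample gradient descent) in which each line is annotated with the PyTorch call it stands in for; it is not the article's training code. The gradient is reset before each update for the same reason zero_grad() is needed: in PyTorch, backward() adds to existing .grad values instead of overwriting them.

```python
def train_toy(xs, ys, lr=0.01, num_epochs=20):
    """Fit y = w * x with per-sample gradient descent, recording the mean loss per epoch."""
    w = 0.0
    loss_history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for x, y in zip(xs, ys):
            grad = 0.0                  # like optimizer.zero_grad(): start from a clean gradient
            pred = w * x                # forward pass
            loss = (pred - y) ** 2      # squared-error loss
            grad += 2 * (pred - y) * x  # like loss.backward(): accumulate d(loss)/dw
            w -= lr * grad              # like optimizer.step(): plain SGD update
            running_loss += loss
        loss_history.append(running_loss / len(xs))  # average loss for this epoch
    return w, loss_history

w, history = train_toy([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
print(round(w, 3), history[0], history[-1])
```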

Visualizing the loss curve

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.show()

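The x-axis is the epoch and the y-axis is the loss. If the curve keeps falling, the model is converging; if it oscillates or rises, check the learning rate or the stability of the data.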

Complete training code

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from network import WaterQualityNet
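import os
# Work around "OMP: Error #15" (duplicate OpenMP runtime), a crash seen with some Windows/Anaconda setups
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"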

class WaterQualityDataset(Dataset):
    def __init__(self, root, transform=None):
        self.dataset = ImageFolder(root, transform=transform)
    def __getitem__(self, index):
        return self.dataset[index]
    def __len__(self):
        return len(self.dataset)

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

train_dataset = WaterQualityDataset('D:/dataset', transform=transform)
test_dataset = WaterQualityDataset('D:/dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

model = WaterQualityNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 10
loss_history = []

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    loss_history.append(epoch_loss)
    print(f'Training Loss: {epoch_loss}')

torch.save(model, 'water_quality_full_model.pth')  # pickles the whole model object, not just the state_dict
print('Model saved as water_quality_full_model.pth')

plt.figure(figsize=(8, 5))
plt.plot(range(1, num_epochs + 1), loss_history, marker='o', linestyle='-', color='b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.show()



Evaluation

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

test_loader = DataLoader(
    datasets.ImageFolder('D:/dataset1', transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])),
    batch_size=32, shuffle=False, num_workers=0
)

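device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The file holds a whole pickled model: PyTorch >= 2.6 needs weights_only=False, and network.py must be importable
model = torch.load('water_quality_full_model.pth', map_location=device, weights_only=False)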
model.to(device)
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Evaluating'):
        outputs = model(images.to(device))
        correct += (outputs.argmax(1) == labels.to(device)).sum().item()
        total += labels.size(0)

print(f"Top-1 Accuracy: {correct / total * 100:.2f}%")

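Why no Softmax here? Softmax preserves the ordering of the logits: it exponentiates each one (a strictly increasing function) and divides them all by the same positive sum. Taking the argmax of the raw logits therefore gives exactly the same class as taking it after Softmax, so the cheaper option is to skip the Softmax step.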
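The argument is easy to check numerically. The plain-Python sketch below uses made-up logits for three images, confirms that argmax gives the same predictions with and without Softmax, and then computes Top-1 accuracy the same way the evaluation loop does (correct predictions divided by total samples).

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=lambda i: values[i])

batch_logits = [[2.0, -1.0], [0.3, 0.7], [-0.5, -0.2]]  # made-up logits for 3 images
labels = [0, 1, 0]

preds_from_logits = [argmax(row) for row in batch_logits]
preds_from_probs = [argmax(softmax(row)) for row in batch_logits]
print(preds_from_logits == preds_from_probs)  # True

correct = sum(p == y for p, y in zip(preds_from_logits, labels))
print(f"Top-1 Accuracy: {correct / len(labels) * 100:.2f}%")  # Top-1 Accuracy: 66.67%
```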


UI page

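Below is an interactive interface built with PyQt5: the user selects a water image, and the model returns its prediction together with a confidence score.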

import sys
import torch
import torchvision.transforms as transforms
from PIL import Image
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QFileDialog, QVBoxLayout, QMessageBox
from PyQt5.QtGui import QPixmap, QFont
from PyQt5.QtCore import Qt

# --- Import the network definition ---
# torch.load unpickles the full model, so the WaterQualityNet class from network.py must be importable
try:
    from network import WaterQualityNet
except ImportError:
    print("Error: WaterQualityNet not found; check that network.py is on the import path!")
    sys.exit(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "water_quality_full_model.pth"
model = None
try:
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()
    print(f"Model loaded successfully on {device}")
except Exception as e:
    print(f"Failed to load the model; falling back to an untrained WaterQualityNet, so predictions will be meaningless. Error details: {str(e)}")
    model = WaterQualityNet().to(device)
    model.eval()

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

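# These indices must match ImageFolder's class_to_idx (sorted folder names) used during training
QUALITY_MAPPING = {
    0: "Clean water (Class I)",
    1: "Slightly polluted water (Class II)"
}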

class WaterQualityApp(QWidget):
    def __init__(self):
        super().__init__()
        self.image_path = None
        self.init_ui()
        self.apply_style()

    def init_ui(self):
        layout = QVBoxLayout()
        layout.setSpacing(15)

        self.title_label = QLabel("Smart Water Quality Image Detection System")
        self.title_label.setObjectName("TitleLabel")
        layout.addWidget(self.title_label, alignment=Qt.AlignCenter)

        self.image_label = QLabel("Waiting for an image...")
        self.image_label.setObjectName("ImageLabel")
        layout.addWidget(self.image_label, stretch=1)

        self.upload_button = QPushButton("Upload water image")
        self.upload_button.setObjectName("UploadButton")
        self.upload_button.clicked.connect(self.load_image)
        layout.addWidget(self.upload_button)

        self.detect_button = QPushButton("Detect water quality")
        self.detect_button.setObjectName("DetectButton")
        self.detect_button.clicked.connect(self.detect_water_quality)
        layout.addWidget(self.detect_button)

        self.result_label = QLabel("Results will appear here...")
        self.result_label.setObjectName("ResultLabel")
        layout.addWidget(self.result_label, alignment=Qt.AlignCenter)

        self.setLayout(layout)
        self.setWindowTitle("Smart Water Quality Detection System")
        self.setGeometry(100, 100, 700, 700)

    def apply_style(self):
        self.setStyleSheet("""
            QWidget { background-color: #f8f9fa; font-family: 'Inter', sans-serif; }
            #TitleLabel { font-size: 28px; font-weight: bold; color: #007bff; padding: 15px; }
            #ImageLabel { border: 3px dashed #ced4da; border-radius: 10px; background-color: #ffffff; min-height: 300px; font-size: 16px; color: #6c757d; }
            QPushButton { font-size: 16px; padding: 12px; border-radius: 8px; border: none; font-weight: 500; color: white; }
            #UploadButton { background-color: #28a745; }
            #UploadButton:hover { background-color: #218838; }
            #DetectButton { background-color: #007bff; }
            #DetectButton:hover { background-color: #0069d9; }
            #ResultLabel { font-size: 22px; padding: 10px; border-radius: 8px; font-weight: bold; background-color: #e9ecef; color: #495057; margin-top: 15px; }
        """)

    def load_image(self):
        file_dialog = QFileDialog()
        image_path, _ = file_dialog.getOpenFileName(self, "Select a water image", "", "Image files (*.png *.jpg *.jpeg *.bmp)")
        if image_path:
            try:
                pixmap = QPixmap(image_path)
                scaled_pixmap = pixmap.scaled(self.image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
                self.image_label.setPixmap(scaled_pixmap)
                self.image_label.setAlignment(Qt.AlignCenter)
                self.image_label.setText("")
                self.image_path = image_path
                self.result_label.setStyleSheet("font-size: 22px; padding: 10px; border-radius: 8px; font-weight: bold; background-color: #e9ecef; color: #495057; margin-top: 15px;")
                self.result_label.setText("Image uploaded; click 'Detect water quality'...")
            except Exception as e:
                QMessageBox.warning(self, "File error", f"Could not load the image file: {str(e)}")

    def detect_water_quality(self):
        if self.image_path and model is not None:
            try:
                image = Image.open(self.image_path).convert("RGB")
                image_tensor = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(image_tensor)
                    probabilities = torch.softmax(output, dim=1)
                    _, predicted_class_idx = torch.max(output, 1)

                print("\n" + "="*50)
                print("--- From model output to final prediction ---")
                print(f"1. Raw logits: {output.squeeze().tolist()}")
                print(f"2. Softmax probabilities: {probabilities.squeeze().tolist()}")
                predicted_index = predicted_class_idx.item()
                print(f"3. Predicted index: {predicted_index}")

                water_quality = QUALITY_MAPPING.get(predicted_index, "Unknown class")
                confidence = probabilities[0, predicted_index].item()
                print(f"4. Mapped result: {water_quality} (confidence: {confidence*100:.2f}%)")
                print("="*50)

                if predicted_index == 0:
                    color = "#28a745"
                    bg_color = "#d4edda"
                else:
                    color = "#dc3545"
                    bg_color = "#f8d7da"

                self.result_label.setStyleSheet(f"color: {color}; font-size: 24px; margin-top: 20px; font-weight: bold; background-color: {bg_color}; border: 1px solid {color}; padding: 15px; border-radius: 8px;")
                self.result_label.setText(f"{water_quality} | Confidence: {confidence*100:.2f}%")
            except Exception as e:
                self.result_label.setStyleSheet("color: #6c757d; font-size: 18px; margin-top: 20px; background-color: #fff3cd;")
                self.result_label.setText(f"Detection failed; check the model or input: {str(e)[:50]}...")
        else:
            QMessageBox.warning(self, "Notice", "Please upload an image first!")

if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = WaterQualityApp()
    window.show()
    sys.exit(app.exec_())
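The post-processing inside detect_water_quality (logits, then Softmax probabilities, then argmax, then a text label) can be isolated into a small pure function, which makes it testable without Qt or a trained model. `predict_label` is a hypothetical helper and the logits are made up; the confidence is simply the Softmax probability of the predicted class.

```python
import math

# Hypothetical label map; its indices must follow ImageFolder's class_to_idx order
QUALITY_MAPPING = {0: "Clean water (Class I)", 1: "Slightly polluted water (Class II)"}

def predict_label(logits, mapping):
    """Logits -> (label text, confidence), mirroring the steps printed by detect_water_quality."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # stable softmax numerators
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return mapping.get(idx, "Unknown class"), probs[idx]

label, confidence = predict_label([0.2, 1.4], QUALITY_MAPPING)
print(f"{label} | Confidence: {confidence * 100:.2f}%")
```

This confidence only reflects how the model splits probability between its two classes; it is not necessarily a calibrated estimate of how likely the water really is clean.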



Source: https://developer.aliyun.com/article/1739917
