Dockerized-Rust-ML/project/src/neural.rs
2025-03-19 16:01:24 -04:00

77 lines
1.8 KiB
Rust

// Libraries
mod data;
mod inference;
mod model;
mod training;
use super::config::OperationMode;
use burn::{
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
// Constants
/// Directory path passed as the first argument to both `training::train` and
/// `inference::infer`. Presumably where model artifacts are written and later
/// read back — confirm against those modules.
const MODEL_DIRECTORY: &str = "./out";
// Structures
/// Driver for the neural-network workflow.
///
/// Stores the [`OperationMode`] chosen at construction; `start` uses it to
/// decide whether to run the training routine or the inference routine.
pub struct NeuralNetwork {
/// Mode that `start` matches on to pick the routine to run.
mode: OperationMode,
}
// Implementations
impl NeuralNetwork {
    // Constants
    /// Index of the MNIST test item used as the fixed inference sample.
    const SAMPLE_INDEX: usize = 42;

    // Constructors
    /// Creates a driver that will run in the given `mode` once [`start`](Self::start)
    /// is called.
    pub fn new(mode: OperationMode) -> Self {
        Self { mode }
    }

    // Functions
    /// Loads the fixed sample from the MNIST test split and hands it to
    /// `inference::infer` together with [`MODEL_DIRECTORY`] and `device`.
    ///
    /// Shared by the training and inference paths so that both always run
    /// inference on the same item with the same backend.
    ///
    /// # Panics
    ///
    /// Panics if the MNIST test split has no item at [`Self::SAMPLE_INDEX`].
    fn run_sample_inference(device: burn::backend::wgpu::WgpuDevice) {
        // Plain (non-autodiff) backend: inference does not need gradients.
        type InferenceBackend = Wgpu<f32, i32>;

        let sample = burn::data::dataset::vision::MnistDataset::test()
            .get(Self::SAMPLE_INDEX)
            .expect("MNIST test split should contain an item at SAMPLE_INDEX");

        inference::infer::<InferenceBackend>(MODEL_DIRECTORY, device, sample);
    }

    /// Trains the model via `training::train`, then runs inference on the
    /// fixed sample using the same device.
    fn train(&self) {
        // Training goes through the autodiff wrapper around the Wgpu backend.
        type TrainingBackend = Autodiff<Wgpu<f32, i32>>;

        // Create a default Wgpu device
        let device = burn::backend::wgpu::WgpuDevice::default();

        // NOTE(review): the arguments (10, 512) are presumably the number of
        // classes and the hidden layer size — confirm against `model::ModelConfig`.
        let config =
            training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new());

        // Train the model; the device is cloned because it is reused below.
        training::train::<TrainingBackend>(MODEL_DIRECTORY, config, device.clone());

        // Infer with the model right after training
        Self::run_sample_inference(device);
    }

    /// Runs inference on the fixed sample using a default Wgpu device.
    fn infer(&self) {
        Self::run_sample_inference(burn::backend::wgpu::WgpuDevice::default());
    }

    /// Runs the routine selected by the mode given at construction:
    /// training (followed by a sample inference) or inference only.
    pub fn start(&self) {
        // Switching based on mode
        match self.mode {
            OperationMode::Training => self.train(),
            OperationMode::Inference => self.infer(),
        }
    }
}