// Libraries mod data; mod infrence; mod model; mod training; use super::config::OperationMode; use burn::{ backend::{Autodiff, WebGpu}, data::dataset::Dataset, optim::AdamConfig, }; // Structures pub struct NeuralNetwork { mode: OperationMode, } // Implementaions impl NeuralNetwork { // Constructors pub fn new(mode: OperationMode) -> Self { // Return Result Self { mode } } // Functions fn train(&self) { type MyBackend = WebGpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device let device = burn::backend::wgpu::WgpuDevice::default(); // All the training artifacts will be saved in this directory let artifact_dir = "/tmp/guide"; // Train the model training::train::( artifact_dir, training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); // Infer the model inference::infer::( artifact_dir, device, burn::data::dataset::vision::MnistDataset::test() .get(42) .unwrap(), ); } fn infer(&self) { type MyBackend = WebGpu; let device = burn::backend::wgpu::WgpuDevice::default(); // All the training artifacts are saved in this directory let artifact_dir = "/tmp/guide"; // Infer the model infrence::infer::( artifact_dir, device, burn::data::dataset::vision::MnistDataset::test() .get(42) .unwrap(), ); } pub fn start(&self) { // Switching based on mode match self.mode { OperationMode::Training => self.train(), OperationMode::Infrence => self.infer(), } } }