// Libraries mod data; mod inference; mod model; mod training; use super::config::OperationMode; use burn::{ backend::{Autodiff, Wgpu}, data::dataset::Dataset, optim::AdamConfig, }; // Constants const MODEL_DIRECTORY: &str = "./out"; // Structures pub struct NeuralNetwork { mode: OperationMode, } // Implementaions impl NeuralNetwork { // Constructors pub fn new(mode: OperationMode) -> Self { // Return Result Self { mode } } // Functions fn train(&self) { // Creating Backend type MyBackend = Wgpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device let device = burn::backend::wgpu::WgpuDevice::default(); // Train the model training::train::( MODEL_DIRECTORY, training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()), device.clone(), ); // Infer the model inference::infer::( MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42) .unwrap(), ); } fn infer(&self) { // Creating Backend type MyBackend = Wgpu; let device = burn::backend::wgpu::WgpuDevice::default(); // Infer the model inference::infer::( MODEL_DIRECTORY, device, burn::data::dataset::vision::MnistDataset::test() .get(42) .unwrap(), ); } pub fn start(&self) { // Switching based on mode match self.mode { OperationMode::Training => self.train(), OperationMode::Inference => self.infer(), } } }