generated from maddoxwerts/Dockerized-Rust
77 lines
1.8 KiB
Rust
77 lines
1.8 KiB
Rust
// Libraries
|
|
mod data;
|
|
mod inference;
|
|
mod model;
|
|
mod training;
|
|
use super::config::OperationMode;
|
|
|
|
use burn::{
|
|
backend::{Autodiff, Wgpu},
|
|
data::dataset::Dataset,
|
|
optim::AdamConfig,
|
|
};
|
|
|
|
// Constants
|
|
/// Directory path handed to both `training::train` and `inference::infer`.
/// NOTE(review): presumably where model artifacts are written and later read
/// back; confirm against those modules.
const MODEL_DIRECTORY: &str = "./out";
|
|
|
|
// Structures
|
|
pub struct NeuralNetwork {
|
|
mode: OperationMode,
|
|
}
|
|
|
|
// Implementations
|
|
impl NeuralNetwork {
|
|
// Constructors
|
|
pub fn new(mode: OperationMode) -> Self {
|
|
// Return Result
|
|
Self { mode }
|
|
}
|
|
|
|
// Functions
|
|
fn train(&self) {
|
|
// Creating Backend
|
|
type MyBackend = Wgpu<f32, i32>;
|
|
type MyAutodiffBackend = Autodiff<MyBackend>;
|
|
|
|
// Create a default Wgpu device
|
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
|
|
|
// Train the model
|
|
training::train::<MyAutodiffBackend>(
|
|
MODEL_DIRECTORY,
|
|
training::TrainingConfig::new(model::ModelConfig::new(10, 512), AdamConfig::new()),
|
|
device.clone(),
|
|
);
|
|
|
|
// Infer the model
|
|
inference::infer::<MyBackend>(
|
|
MODEL_DIRECTORY,
|
|
device,
|
|
burn::data::dataset::vision::MnistDataset::test()
|
|
.get(42)
|
|
.unwrap(),
|
|
);
|
|
}
|
|
fn infer(&self) {
|
|
// Creating Backend
|
|
type MyBackend = Wgpu<f32, i32>;
|
|
let device = burn::backend::wgpu::WgpuDevice::default();
|
|
|
|
// Infer the model
|
|
inference::infer::<MyBackend>(
|
|
MODEL_DIRECTORY,
|
|
device,
|
|
burn::data::dataset::vision::MnistDataset::test()
|
|
.get(42)
|
|
.unwrap(),
|
|
);
|
|
}
|
|
|
|
pub fn start(&self) {
|
|
// Switching based on mode
|
|
match self.mode {
|
|
OperationMode::Training => self.train(),
|
|
OperationMode::Inference => self.infer(),
|
|
}
|
|
}
|
|
}
|