diff --git a/tests/test_engine.py b/tests/test_engine.py index 0159111..6586dcf 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -32,6 +32,7 @@ class MockModel: self.vocab_size = vocab_size self.config = MockConfig() self._device = "cpu" + self._device = torch.device("cpu") def get_device(self): return self._device