I've used PyTorch successfully in a macOS VM running on a macOS host using https://tart.run/, so I'd expect it to work here too.
Update: PyTorch for Linux on ARM isn't built with Apple's MPS (Metal Performance Shaders) support, so the pip-installed version didn't work. Perhaps it's possible to compile it from source with MPS enabled.
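If you want to check what a given install supports, here's a quick sketch. It uses `torch.backends.mps.is_built()` (was this wheel compiled with MPS?) and `torch.backends.mps.is_available()` (is a usable Metal device present?). The helper name `mps_status` is just made up for this example. On a Linux ARM guest I'd expect `mps_built` to come back `False`.

```python
import platform


def mps_status():
    """Hypothetical helper: report whether this PyTorch build has MPS and whether it's usable."""
    try:
        import torch
    except ImportError:
        # PyTorch isn't installed in this environment at all
        return {"platform": platform.system(), "torch": None,
                "mps_built": False, "mps_available": False}
    return {
        "platform": platform.system(),
        "torch": torch.__version__,
        # True only if this wheel was compiled with the MPS backend
        "mps_built": torch.backends.mps.is_built(),
        # True only if the backend is built AND a Metal device is usable here
        "mps_available": torch.backends.mps.is_available(),
    }


if __name__ == "__main__":
    print(mps_status())
```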