Upload app.py
Browse files
app.py
CHANGED
|
@@ -83,7 +83,7 @@ def get_paths(dataset_id):
|
|
| 83 |
|
| 84 |
|
| 85 |
def load_pgm(dataset_id, pgm_path):
|
| 86 |
-
checkpoint = torch.load(pgm_path, map_location=DEVICE
|
| 87 |
args = Hparams()
|
| 88 |
args.update(checkpoint["hparams"])
|
| 89 |
args.device = DEVICE
|
|
@@ -101,7 +101,7 @@ def load_pgm(dataset_id, pgm_path):
|
|
| 101 |
def load_vae(dataset_id, vae_path):
|
| 102 |
if "Chest" in dataset_id:
|
| 103 |
vae_path, dscm_path = vae_path[0], vae_path[1]
|
| 104 |
-
checkpoint = torch.load(vae_path, map_location=DEVICE
|
| 105 |
args = Hparams()
|
| 106 |
args.update(checkpoint["hparams"])
|
| 107 |
# backwards compatibility hack
|
|
@@ -115,7 +115,7 @@ def load_vae(dataset_id, vae_path):
|
|
| 115 |
vae = HVAE(args).to(args.device)
|
| 116 |
|
| 117 |
if "Chest" in dataset_id:
|
| 118 |
-
dscm_ckpt = torch.load(dscm_path, map_location=DEVICE
|
| 119 |
vae.load_state_dict(
|
| 120 |
{
|
| 121 |
k[4:]: v
|
|
|
|
| 83 |
|
| 84 |
|
| 85 |
def load_pgm(dataset_id, pgm_path):
|
| 86 |
+
checkpoint = torch.load(pgm_path, map_location=DEVICE)
|
| 87 |
args = Hparams()
|
| 88 |
args.update(checkpoint["hparams"])
|
| 89 |
args.device = DEVICE
|
|
|
|
| 101 |
def load_vae(dataset_id, vae_path):
|
| 102 |
if "Chest" in dataset_id:
|
| 103 |
vae_path, dscm_path = vae_path[0], vae_path[1]
|
| 104 |
+
checkpoint = torch.load(vae_path, map_location=DEVICE)
|
| 105 |
args = Hparams()
|
| 106 |
args.update(checkpoint["hparams"])
|
| 107 |
# backwards compatibility hack
|
|
|
|
| 115 |
vae = HVAE(args).to(args.device)
|
| 116 |
|
| 117 |
if "Chest" in dataset_id:
|
| 118 |
+
dscm_ckpt = torch.load(dscm_path, map_location=DEVICE)
|
| 119 |
vae.load_state_dict(
|
| 120 |
{
|
| 121 |
k[4:]: v
|