From cd16920eb7c1edacab36bd6d240c4db24f9528c7 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Thu, 9 Jan 2020 23:27:43 +0100 Subject: [PATCH] Added device parameter to TorchOperator --- pylops_gpu/TorchOperator.py | 17 +++++++++++------ tutorials/poststack.py | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pylops_gpu/TorchOperator.py b/pylops_gpu/TorchOperator.py index 4772be6..4bf32f6 100644 --- a/pylops_gpu/TorchOperator.py +++ b/pylops_gpu/TorchOperator.py @@ -10,16 +10,17 @@ class _TorchOperator(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, forw, adj, pylops): + def forward(ctx, x, forw, adj, pylops, device): ctx.forw = forw ctx.adj = adj ctx.pylops = pylops + ctx.device = device if ctx.pylops: x = x.cpu().detach().numpy() y = ctx.forw(x) if ctx.pylops: - y = torch.from_numpy(y) + y = torch.from_numpy(y).to(ctx.device) return y @staticmethod @@ -28,8 +29,8 @@ def backward(ctx, y): y = y.cpu().detach().numpy() x = ctx.adj(y) if ctx.pylops: - x = torch.from_numpy(x) - return x, None, None, None + x = torch.from_numpy(x).to(cxt.device) + return x, None, None, None, None class TorchOperator(): @@ -56,6 +57,8 @@ class TorchOperator(): pylops : :obj:`bool`, optional ``Op`` is a pylops operator (``True``) or a pylops-gpu operator (``False``) + device : :obj:`str`, optional + Device to be used for output vectors when ``Op`` is a pylops operator Returns ------- @@ -63,8 +66,9 @@ class TorchOperator(): Output array resulting from the application of the operator to ``x``. """ - def __init__(self, Op, batch=False, pylops=False): + def __init__(self, Op, batch=False, pylops=False, device='cpu'): self.pylops = pylops + self.device = device if not batch: self.matvec = Op.matvec self.rmatvec = Op.rmatvec @@ -86,4 +90,5 @@ def apply(self, x): Output array resulting from the application of the operator to ``x``. """ - return _TorchOperator.apply(x, self.matvec, self.rmatvec, self.pylops) + return _TorchOperator.apply(x, self.matvec, self.rmatvec, + self.pylops, self.device) diff --git a/tutorials/poststack.py b/tutorials/poststack.py index 8ace3f5..077197b 100644 --- a/tutorials/poststack.py +++ b/tutorials/poststack.py @@ -1,5 +1,5 @@ r""" -01. Post-stack inversion +02. Post-stack inversion ======================== This tutorial focuses on extending post-stack seismic inversion to GPU processing. We refer to the equivalent `PyLops tutorial `_