Skip to content

Commit

Permalink
Added device parameter to TorchOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Jan 9, 2020
1 parent 700f663 commit cd16920
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
17 changes: 11 additions & 6 deletions pylops_gpu/TorchOperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ class _TorchOperator(torch.autograd.Function):
"""
@staticmethod
def forward(ctx, x, forw, adj, pylops):
def forward(ctx, x, forw, adj, pylops, device):
ctx.forw = forw
ctx.adj = adj
ctx.pylops = pylops
ctx.device = device

if ctx.pylops:
x = x.cpu().detach().numpy()
y = ctx.forw(x)
if ctx.pylops:
y = torch.from_numpy(y)
y = torch.from_numpy(y).to(ctx.device)
return y

@staticmethod
Expand All @@ -28,8 +29,8 @@ def backward(ctx, y):
y = y.cpu().detach().numpy()
x = ctx.adj(y)
if ctx.pylops:
x = torch.from_numpy(x)
return x, None, None, None
x = torch.from_numpy(x).to(cxt.device)
return x, None, None, None, None


class TorchOperator():
Expand All @@ -56,15 +57,18 @@ class TorchOperator():
pylops : :obj:`bool`, optional
``Op`` is a pylops operator (``True``) or a pylops-gpu
operator (``False``)
device : :obj:`str`, optional
Device to be used for output vectors when ``Op`` is a pylops operator
Returns
-------
y : :obj:`torch.Tensor`
Output array resulting from the application of the operator to ``x``.
"""
def __init__(self, Op, batch=False, pylops=False):
def __init__(self, Op, batch=False, pylops=False, device='cpu'):
self.pylops = pylops
self.device = device
if not batch:
self.matvec = Op.matvec
self.rmatvec = Op.rmatvec
Expand All @@ -86,4 +90,5 @@ def apply(self, x):
Output array resulting from the application of the operator to ``x``.
"""
return _TorchOperator.apply(x, self.matvec, self.rmatvec, self.pylops)
return _TorchOperator.apply(x, self.matvec, self.rmatvec,
self.pylops, self.device)
2 changes: 1 addition & 1 deletion tutorials/poststack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
r"""
01. Post-stack inversion
02. Post-stack inversion
========================
This tutorial focuses on extending post-stack seismic inversion to GPU
processing. We refer to the equivalent `PyLops tutorial <https://pylops.readthedocs.io/en/latest/tutorials/poststack.html>`_
Expand Down

0 comments on commit cd16920

Please sign in to comment.