Skip to content

Commit

Permalink
try mycli/pgcli in order of: virtualenv > system > regular mysql/psql
Browse files Browse the repository at this point in the history
  • Loading branch information
jontsai committed May 28, 2019
1 parent af6a47d commit 557c9cd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Credits
=======

* Simon Percivall <[email protected]>
* Jonathan Tsai <[email protected]>
41 changes: 35 additions & 6 deletions lib/django_dbshell_plus/management/commands/dbshell_plus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import errno
import subprocess
import sys

from django.db import connections
from django.core.management.commands import dbshell
Expand All @@ -18,18 +19,46 @@ def handle(self, **options):
cmd = dbclis.get(connection.vendor)
if cmd:
try:
# attempt to use mycli/pgcli
getattr(self, cmd)(connection)
return
except OSError, e:
if e.errno != errno.ENOENT:
self.stderr.write("Could not start %s: %s" % (cmd, str(e)))
if self._is_virtualenv:
try:
# retry without explicitly without virtualenv
getattr(self, cmd)(connection, ignore_virtualenv=True)
return
except OSError, e:
if e.errno != errno.ENOENT:
self.stderr.write("Could not start %s: %s" % (cmd, str(e)))
else:
if e.errno != errno.ENOENT:
self.stderr.write("Could not start %s: %s" % (cmd, str(e)))

# default to system mysql/psql
super(Command, self).handle(**options)

def pgcli(self, connection):
@property
def _is_virtualenv(self):
# sys.real_prefix is only set if inside virtualenv
return hasattr(sys, 'real_prefix')

@property
def _python_path(self):
path = '{}/bin/'.format(sys.prefix) if self._is_virtualenv else ''
return path

def _get_cli_command(self, cli, ignore_virtualenv=False):
cli_command = '{}{}'.format(
'' if ignore_virtualenv else self._python_path,
cli, # 'pgcli' or 'mycli'
)
return cli_command

def pgcli(self, connection, ignore_virtualenv=False):
# argument code copied from Django
settings_dict = connection.settings_dict
args = ['pgcli']
args = [self._get_cli_command('pgcli', ignore_virtualenv=ignore_virtualenv)]
if settings_dict['USER']:
args += ["-U", settings_dict['USER']]
if settings_dict['HOST']:
Expand All @@ -40,10 +69,10 @@ def pgcli(self, connection):

subprocess.call(args)

def mycli(self, connection):
def mycli(self, connection, ignore_virtualenv=False):
# argument code copied from Django
settings_dict = connection.settings_dict
args = ['mycli']
args = [self._get_cli_command('mycli', ignore_virtualenv=ignore_virtualenv)]
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
passwd = settings_dict['OPTIONS'].get('passwd', settings_dict['PASSWORD'])
Expand Down

0 comments on commit 557c9cd

Please sign in to comment.