Skip to content

Commit

Permalink
Merge pull request #189 from proteneer/ast_fix
Browse files Browse the repository at this point in the history
Use AST as opposed to unsafe eval.
  • Loading branch information
j-wags committed Feb 26, 2019
2 parents 4bfe675 + 1a3ffa6 commit f63ddbc
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions openforcefield/typing/engines/smirnoff/forcefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#=============================================================================================
# GLOBAL IMPORTS
#=============================================================================================

import ast
import operator as op
import collections
import sys
import string

Expand Down Expand Up @@ -58,6 +60,7 @@

import itertools


#=============================================================================================
# PRIVATE SUBROUTINES
#=============================================================================================
Expand Down Expand Up @@ -1102,6 +1105,28 @@ def _validateSMIRKS(smirks, node=None):

return smirks

def _ast_eval(node):
"""
Performs an algebraic syntax tree evaluation of a unit.
"""

operators = {ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
ast.Div: op.truediv, ast.Pow: op.pow, ast.BitXor: op.xor,
ast.USub: op.neg}

if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](_ast_eval(node.left), _ast_eval(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](_ast_eval(node.operand))
elif isinstance(node, ast.Name):
# see if this is a simtk unit
b = getattr(unit, node.id)
return b
else:
raise TypeError(node)

def _extractQuantity(node, parent, name, unit_name=None):
"""
Form a (potentially unit-bearing) quantity from the specified attribute name.
Expand Down Expand Up @@ -1137,9 +1162,11 @@ def _extractQuantity(node, parent, name, unit_name=None):
unit_name = name + '_unit'

if unit_name in parent.attrib:
# TODO: This is very dangerous.
string = '(%s * %s).value_in_unit_system(md_unit_system)' % (node.attrib[name], parent.attrib[unit_name])
quantity = eval(string, unit.__dict__)
a = node.attrib[name]
b = parent.attrib[unit_name]
parsed_units = _ast_eval(ast.parse(b, mode='eval').body)
result = float(a) * parsed_units
quantity = result.value_in_unit_system(unit.md_unit_system)

return quantity

Expand Down

0 comments on commit f63ddbc

Please sign in to comment.