Skip to content

Commit

Permalink
Safeley handle nulls in WME constructors from SML
Browse files Browse the repository at this point in the history
If a client passes null to any of the `Create*WME` constructors, Soar segfaults
due to accessing invalid memory. This is because it assumes non-null strings for
attribute names and values.

Fix this by checking for null and using `"nil"` when required. This is in-line
with Soar behavior elsewhere (where the `nil` Symbol gets stringified into an
attribute name). Also print a warning so that the user knows that this happened.
It was likely an accident.

The fixes work for all of our SWIG-bindings, so we only need to test in one. Use
Python to test the conversions to `"nil"`.

Fixes #485.
  • Loading branch information
garfieldnate committed Sep 27, 2024
1 parent a173b29 commit e58a00a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
35 changes: 23 additions & 12 deletions Core/ClientSML/src/sml_ClientWorkingMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
#include "sml_ClientDirect.h"
#include <cassert>

// Null attribute name and value strings must be converted to legal strings;
#define STRINGIFIED_NULL "nil"
#define SAFE_STRING(s, location) ((s) ? (s) : (std::cerr << "Warning: Null string encountered at: " << location << std::endl, STRINGIFIED_NULL))

using namespace sml ;
using namespace soarxml;

Expand Down Expand Up @@ -838,8 +842,10 @@ Identifier* WorkingMemory::GetOutputLink()
StringElement* WorkingMemory::CreateStringWME(Identifier* parent, char const* pAttribute, char const* pValue)
{
assert(m_Agent == parent->GetAgent()) ;
const char* pAttributeSafe = SAFE_STRING(pAttribute, "CreateStringWME, argument 'pAttribute'");
const char* pValueSafe = SAFE_STRING(pValue, "CreateStringWME, argument 'pValue'");

StringElement* pWME = new StringElement(GetAgent(), parent, parent->GetValueAsString(), pAttribute, pValue, GenerateTimeTag()) ;
StringElement* pWME = new StringElement(GetAgent(), parent, parent->GetValueAsString(), pAttributeSafe, pValueSafe, GenerateTimeTag()) ;

// Record that the identifer owns this new WME
parent->AddChild(pWME) ;
Expand All @@ -848,7 +854,7 @@ StringElement* WorkingMemory::CreateStringWME(Identifier* parent, char const* pA
if (GetConnection()->IsDirectConnection())
{
EmbeddedConnection* pConnection = static_cast<EmbeddedConnection*>(GetConnection());
pConnection->DirectAddWME_String(m_AgentSMLHandle, parent->GetValueAsString(), pAttribute, pValue, pWME->GetTimeTag());
pConnection->DirectAddWME_String(m_AgentSMLHandle, parent->GetValueAsString(), pAttributeSafe, pValueSafe, pWME->GetTimeTag());

// Return immediately, without adding it to the commit list.
return pWME ;
Expand All @@ -874,8 +880,9 @@ StringElement* WorkingMemory::CreateStringWME(Identifier* parent, char const* pA
IntElement* WorkingMemory::CreateIntWME(Identifier* parent, char const* pAttribute, long long value)
{
assert(m_Agent == parent->GetAgent()) ;
const char* pAttributeSafe = SAFE_STRING(pAttribute, "CreateIntWME, argument 'pAttribute'");

IntElement* pWME = new IntElement(GetAgent(), parent, parent->GetValueAsString(), pAttribute, value, GenerateTimeTag()) ;
IntElement* pWME = new IntElement(GetAgent(), parent, parent->GetValueAsString(), pAttributeSafe, value, GenerateTimeTag()) ;

// Record that the identifer owns this new WME
parent->AddChild(pWME) ;
Expand All @@ -884,7 +891,7 @@ IntElement* WorkingMemory::CreateIntWME(Identifier* parent, char const* pAttribu
if (GetConnection()->IsDirectConnection())
{
EmbeddedConnection* pConnection = static_cast<EmbeddedConnection*>(GetConnection());
pConnection->DirectAddWME_Int(m_AgentSMLHandle, parent->GetValueAsString(), pAttribute, value, pWME->GetTimeTag());
pConnection->DirectAddWME_Int(m_AgentSMLHandle, parent->GetValueAsString(), pAttributeSafe, value, pWME->GetTimeTag());

// Return immediately, without adding it to the commit list.
return pWME ;
Expand All @@ -910,8 +917,9 @@ IntElement* WorkingMemory::CreateIntWME(Identifier* parent, char const* pAttribu
FloatElement* WorkingMemory::CreateFloatWME(Identifier* parent, char const* pAttribute, double value)
{
assert(m_Agent == parent->GetAgent()) ;
const char* pAttributeSafe = SAFE_STRING(pAttribute, "CreateFloatWME, argument 'pAttribute'");

FloatElement* pWME = new FloatElement(GetAgent(), parent, parent->GetValueAsString(), pAttribute, value, GenerateTimeTag()) ;
FloatElement* pWME = new FloatElement(GetAgent(), parent, parent->GetValueAsString(), pAttributeSafe, value, GenerateTimeTag()) ;

// Record that the identifer owns this new WME
parent->AddChild(pWME) ;
Expand All @@ -920,7 +928,7 @@ FloatElement* WorkingMemory::CreateFloatWME(Identifier* parent, char const* pAtt
if (GetConnection()->IsDirectConnection())
{
EmbeddedConnection* pConnection = static_cast<EmbeddedConnection*>(GetConnection());
pConnection->DirectAddWME_Double(m_AgentSMLHandle, parent->GetValueAsString(), pAttribute, value, pWME->GetTimeTag());
pConnection->DirectAddWME_Double(m_AgentSMLHandle, parent->GetValueAsString(), pAttributeSafe, value, pWME->GetTimeTag());

// Return immediately, without adding it to the commit list.
return pWME ;
Expand Down Expand Up @@ -1148,12 +1156,14 @@ Identifier* WorkingMemory::CreateIdWME(Identifier* parent, char const* pAttribut
{
assert(m_Agent == parent->GetAgent()) ;

char const* pAttributeSafe = SAFE_STRING(pAttribute, "CreateIdWME, argument 'pAttribute'"); ;

// Create a new, unique id (e.g. "i3"). This id will be mapped to a different id
// in the kernel.
std::string id ;
GenerateNewID(pAttribute, &id) ;
GenerateNewID(pAttributeSafe, &id) ;

Identifier* pWME = new Identifier(GetAgent(), parent, parent->GetValueAsString(), pAttribute, id.c_str(), GenerateTimeTag()) ;
Identifier* pWME = new Identifier(GetAgent(), parent, parent->GetValueAsString(), pAttributeSafe, id.c_str(), GenerateTimeTag()) ;

// Record that the identifer owns this new WME
parent->AddChild(pWME) ;
Expand All @@ -1162,7 +1172,7 @@ Identifier* WorkingMemory::CreateIdWME(Identifier* parent, char const* pAttribut
if (GetConnection()->IsDirectConnection())
{
EmbeddedConnection* pConnection = static_cast<EmbeddedConnection*>(GetConnection());
pConnection->DirectAddID(m_AgentSMLHandle, parent->GetValueAsString(), pAttribute, id.c_str(), pWME->GetTimeTag());
pConnection->DirectAddID(m_AgentSMLHandle, parent->GetValueAsString(), pAttributeSafe, id.c_str(), pWME->GetTimeTag());

// Return immediately, without adding it to the commit list.
return pWME ;
Expand Down Expand Up @@ -1190,13 +1200,14 @@ Identifier* WorkingMemory::CreateSharedIdWME(Identifier* parent, char const* pAt
{
assert(m_Agent == parent->GetAgent()) ;
assert(m_Agent == pSharedValue->GetAgent()) ;
const char* pAttributeSafe = SAFE_STRING(pAttribute, "CreateSharedIdWME, argument 'pAttribute'"); ;

// bug 1060
// need to check and make sure that this shared wme will not violate the set
{
// find other wmes on parent with same attribute
WMElement* wme = 0;
for (int i = 0; (wme = parent->FindByAttribute(pAttribute, i)) != 0; ++i)
for (int i = 0; (wme = parent->FindByAttribute(pAttributeSafe, i)) != 0; ++i)
{
if (wme == pSharedValue)
{
Expand All @@ -1209,7 +1220,7 @@ Identifier* WorkingMemory::CreateSharedIdWME(Identifier* parent, char const* pAt
std::string id = pSharedValue->GetValueAsString() ;

// Create the new WME with the same value
Identifier* pWME = new Identifier(GetAgent(), parent, parent->GetValueAsString(), pAttribute, pSharedValue, GenerateTimeTag()) ;
Identifier* pWME = new Identifier(GetAgent(), parent, parent->GetValueAsString(), pAttributeSafe, pSharedValue, GenerateTimeTag()) ;

// Record that the identifer owns this new WME
parent->AddChild(pWME) ;
Expand All @@ -1218,7 +1229,7 @@ Identifier* WorkingMemory::CreateSharedIdWME(Identifier* parent, char const* pAt
if (GetConnection()->IsDirectConnection())
{
EmbeddedConnection* pConnection = static_cast<EmbeddedConnection*>(GetConnection());
pConnection->DirectAddID(m_AgentSMLHandle, parent->GetValueAsString(), pAttribute, id.c_str(), pWME->GetTimeTag());
pConnection->DirectAddID(m_AgentSMLHandle, parent->GetValueAsString(), pAttributeSafe, id.c_str(), pWME->GetTimeTag());

// Return immediately, without adding it to the commit list.
return pWME ;
Expand Down
30 changes: 29 additions & 1 deletion Core/ClientSMLSWIG/Python/TestPythonSML.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
#
# This file needs to be compatible with python 3.5, to run on CI jobs testing the lowest supported python version.
from pathlib import Path
import sys
import os
import re
import sys
import time

try:
Expand Down Expand Up @@ -189,6 +190,33 @@ def test_agent_reinit(agent):

test_agent_reinit(agent)

def test_pass_null_to_WME_constructors(agent):
# The logic for safley handling null arguments in the C++ code is not language-specific,
# so this test does not need to be run in the other SWIG-binding languages
link = agent.GetInputLink()

id_wme = link.CreateIdWME("foo")
link.CreateSharedIdWME(None, id_wme)
link.CreateIdWME(None)
link.CreateStringWME(None, None)
link.CreateIntWME(None, 1)
link.CreateFloatWME(None, 1.0)

agent.ExecuteCommandLine("step")

result_string = agent.ExecuteCommandLine("print I2")
failure_message = lambda details: f"❌ Pass null to WME constructors: {details}\nResult string was: " + result_string + "\n"
assert "^foo F1" in result_string, failure_message("CreateIdWME('foo') failed")
assert "^nil F1" in result_string, failure_message("CreateSharedIdWME(None, id_wme) failed")
assert "^nil N1" in result_string, failure_message("CreateIdWME(None) failed")
assert "^nil nil" in result_string, failure_message("CreateStringWME(None, None) failed")
assert re.search(r"\^nil 1(?!\.)", result_string), failure_message("CreateIntWME(None) failed")
assert "^nil 1.0" in result_string, failure_message("CreateFloatWME(None) failed")

print("✅ Pass null to WME constructors")

test_pass_null_to_WME_constructors(agent)

def test_agent_destroy(agent):
assert not agent_destroy_called.called, "❌ Agent destroy handler called before destroy"
kernel.DestroyAgent(agent)
Expand Down

0 comments on commit e58a00a

Please sign in to comment.