Skip to content

Commit

Permalink
[child-table] check neighbor is in ChildTable before cast as Child (
Browse files Browse the repository at this point in the history
#8071)

This commit adds a check `Get<ChildTable>().Contains(Neighbor &)`
to ensure that the neighbor is from the child table before casting
the neighbor entry to `Child`. This adds safety and protection against
potential corner-case where we have a neighbor entry which is
`Parent` or `ParentCandidate` that is a REED.
  • Loading branch information
abtink authored Aug 25, 2022
1 parent d7cbc17 commit cba1beb
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/core/thread/mesh_forwarder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1834,8 +1834,8 @@ uint16_t MeshForwarder::CalcFrameVersion(const Neighbor *aNeighbor, bool aIePres
version = Mac::Frame::kFcfFrameVersion2015;
}
#if OPENTHREAD_FTD && OPENTHREAD_CONFIG_MAC_CSL_TRANSMITTER_ENABLE
else if (aNeighbor != nullptr && !Mle::IsActiveRouter(aNeighbor->GetRloc16()) &&
Get<Mle::MleRouter>().IsRouterOrLeader() && static_cast<const Child *>(aNeighbor)->IsCslSynchronized())
else if ((aNeighbor != nullptr) && Get<ChildTable>().Contains(*aNeighbor) &&
static_cast<const Child *>(aNeighbor)->IsCslSynchronized())
{
version = Mac::Frame::kFcfFrameVersion2015;
}
Expand Down
28 changes: 15 additions & 13 deletions src/core/thread/mesh_forwarder_ftd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ Error MeshForwarder::SendMessage(Message &aMessage)
{
Mle::MleRouter &mle = Get<Mle::MleRouter>();
Error error = kErrorNone;
Neighbor * neighbor;

aMessage.SetOffset(0);
aMessage.SetDatagramTag(0);
Expand Down Expand Up @@ -104,17 +103,20 @@ Error MeshForwarder::SendMessage(Message &aMessage)
}
}
}
else if ((neighbor = Get<NeighborTable>().FindNeighbor(ip6Header.GetDestination())) != nullptr &&
!neighbor->IsRxOnWhenIdle() && !aMessage.IsDirectTransmission())
else // Destination is unicast
{
// destined for a sleepy child
Child &child = *static_cast<Child *>(neighbor);
mIndirectSender.AddMessageForSleepyChild(aMessage, child);
}
else
{
// schedule direct transmission
aMessage.SetDirectTransmission();
Neighbor *neighbor = Get<NeighborTable>().FindNeighbor(ip6Header.GetDestination());

if ((neighbor != nullptr) && !neighbor->IsRxOnWhenIdle() && !aMessage.IsDirectTransmission() &&
Get<ChildTable>().Contains(*neighbor))
{
// Destined for a sleepy child
mIndirectSender.AddMessageForSleepyChild(aMessage, *static_cast<Child *>(neighbor));
}
else
{
aMessage.SetDirectTransmission();
}
}

break;
Expand Down Expand Up @@ -283,7 +285,7 @@ void MeshForwarder::RemoveMessages(Child &aChild, Message::SubType aSubType)

IgnoreError(message.Read(0, ip6header));

if (&aChild == static_cast<Child *>(Get<NeighborTable>().FindNeighbor(ip6header.GetDestination())))
if (&aChild == Get<NeighborTable>().FindNeighbor(ip6header.GetDestination()))
{
message.ClearDirectTransmission();
}
Expand All @@ -297,7 +299,7 @@ void MeshForwarder::RemoveMessages(Child &aChild, Message::SubType aSubType)

IgnoreError(meshHeader.ParseFrom(message));

if (&aChild == static_cast<Child *>(Get<NeighborTable>().FindNeighbor(meshHeader.GetDestination())))
if (&aChild == Get<NeighborTable>().FindNeighbor(meshHeader.GetDestination()))
{
message.ClearDirectTransmission();
}
Expand Down
5 changes: 3 additions & 2 deletions src/core/thread/mle_router.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2724,7 +2724,8 @@ void MleRouter::HandleChildUpdateResponse(RxInfo &aRxInfo)
Child * child;
uint16_t addressRegistrationOffset = 0;

if ((aRxInfo.mNeighbor == nullptr) || IsActiveRouter(aRxInfo.mNeighbor->GetRloc16()))
if ((aRxInfo.mNeighbor == nullptr) || IsActiveRouter(aRxInfo.mNeighbor->GetRloc16()) ||
!Get<ChildTable>().Contains(*aRxInfo.mNeighbor))
{
Log(kMessageReceive, kTypeChildUpdateResponseOfUnknownChild, aRxInfo.mMessageInfo.GetPeerAddr());
ExitNow(error = kErrorNotFound);
Expand Down Expand Up @@ -3529,7 +3530,7 @@ void MleRouter::RemoveNeighbor(Neighbor &aNeighbor)
}
else if (!IsActiveRouter(aNeighbor.GetRloc16()))
{
OT_ASSERT(mChildTable.GetChildIndex(static_cast<Child &>(aNeighbor)) < kMaxChildren);
OT_ASSERT(mChildTable.Contains(aNeighbor));

if (aNeighbor.IsStateValidOrRestoring())
{
Expand Down
1 change: 1 addition & 0 deletions src/core/thread/neighbor_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ void NeighborTable::Signal(Event aEvent, const Neighbor &aNeighbor)
case kChildRemoved:
case kChildModeChanged:
#if OPENTHREAD_FTD
OT_ASSERT(Get<ChildTable>().Contains(aNeighbor));
static_cast<Child::Info &>(info.mInfo.mChild).SetFrom(static_cast<const Child &>(aNeighbor));
#endif
break;
Expand Down

0 comments on commit cba1beb

Please sign in to comment.