# zktraffic/base/sniffer.py (229 lines of code) (raw):
# ==================================================================================================
# Copyright 2014 Twitter, Inc.
# --------------------------------------------------------------------------------------------------
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this work except in compliance with the License.
# You may obtain a copy of the License in the LICENSE file, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==================================================================================================
from collections import defaultdict
from random import random
from threading import Thread
import logging
import os
import hexdump
import signal
import socket
import struct
import sys
from .client_message import ClientMessage, Request
from .network import BadPacket, get_ip, get_ip_packet, SnifferBase
from .server_message import Reply, ServerMessage, WatchEvent
from .zookeeper import DeserializationError, OpCodes
from .util import StringTooLong, to_bytes
from scapy.config import conf as scapy_conf
scapy_conf.logLevel = logging.ERROR # shush scapy
from scapy.sendrecv import sniff
from six.moves import intern
from twitter.common import log
# Port assumed for the ZooKeeper server unless a config overrides it.
DEFAULT_PORT = 2181

# Four-letter admin commands, as bytes.  Kept as a tuple so the whole collection
# can be handed to `startswith()` in a single call.
FOUR_LETTER_WORDS = tuple(
    to_bytes(word)
    for word in (
        'conf',
        'cons',
        'crst',
        'dump',
        'envi',
        'ruok',
        'srst',
        'srvr',
        'stat',
        'wchs',
        'wchc',
        'wchp',
        'mntr',
        'kill',  # deprecated
        'reqs',  # deprecated
    )
)
class SnifferConfig(object):
    """
    Settings consumed by the sniffer: capture interface, capture filter,
    opcode exclusions, sampling and reply-tracking limits.
    """

    def __init__(self,
                 iface="eth0",
                 writes_only=False,
                 debug=False):
        """
        Build a config with default settings, an initial capture filter and
        pings excluded.

        if client_port is 0 we sniff all clients
        if zookeeper_port is changed later on you must call update_filter()

        :param iface: name of the network interface to capture on
        :param writes_only: stored flag asking for write requests only
        :param debug: stored debug flag (reported by __str__)
        """
        self.iface = iface
        self.writes_only = writes_only
        self.debug = debug
        self.client_port = 0
        self.track_replies = False
        self.max_queued_requests = 10000
        self.zookeeper_port = DEFAULT_PORT
        self.excluded_opcodes = set()
        self.is_loopback = iface in ["lo", "lo0"]
        self.read_timeout_ms = 0
        self.dump_bad_packet = False
        self.sampling = 1.0  # fraction of packets to inspect, in [0, 1]

        # These are set after initialization, and require `update_filter` to be called
        self.included_ips = []
        self.excluded_ips = []

        self.update_filter()
        self.exclude_pings()

    def update_filter(self):
        """
        Rebuild `self.filter` from `zookeeper_port` and the IP lists.

        `included_ips` and `excluded_ips` are mutually exclusive.  On failure
        the previous filter is left untouched.

        :raises AssertionError: if both `included_ips` and `excluded_ips` are set
        """
        if self.included_ips and self.excluded_ips:
            # Raised explicitly (instead of using an `assert` statement) so the
            # check is not stripped when Python runs with -O.
            raise AssertionError("included_ips and excluded_ips are mutually exclusive")

        clauses = ["port %d" % (self.zookeeper_port)]
        if self.excluded_ips:
            clauses.extend("host not " + ip for ip in self.excluded_ips)
        elif self.included_ips:
            clauses.append("(host " + " or host ".join(self.included_ips) + ")")
        self.filter = " and ".join(clauses)

    def include_pings(self):
        """Stop excluding messages whose opcode is OpCodes.PING."""
        self.update_exclusion_list(OpCodes.PING, False)

    def exclude_pings(self):
        """Exclude messages whose opcode is OpCodes.PING."""
        self.update_exclusion_list(OpCodes.PING, True)

    def excluded(self, opcode):
        """
        :returns: True if `opcode` is currently in the exclusion set
        """
        return opcode in self.excluded_opcodes

    def update_exclusion_list(self, opcode, exclude):
        """
        Add `opcode` to the exclusion set (exclude is truthy) or remove it
        (exclude is falsy).  Removing an opcode that is not excluded is a no-op.
        """
        if exclude:
            self.excluded_opcodes.add(opcode)
        else:
            self.excluded_opcodes.discard(opcode)

    def __str__(self):
        """Return a multi-line, human-readable summary of the main settings."""
        template = (
            "\n"
            "***sniffer config ***\n"
            "iface = %s\n"
            "writes_only = %s\n"
            "filter = %s\n"
            "zookeeper_port = %d\n"
            "is_loopback = %s\n"
            "read_timeout_ms = %d\n"
            "debug = %s\n"
        )
        return template % (
            self.iface,
            str(self.writes_only).lower(),
            self.filter,
            self.zookeeper_port,
            str(self.is_loopback),
            self.read_timeout_ms,
            str(self.debug).lower(),
        )
class Sniffer(SnifferBase):
    """
    Captures traffic matching `config.filter`, decodes each packet into a
    client or server message and dispatches it to the registered handlers.
    """

    class RegistrationError(Exception):
        """Raised when a handler is registered twice for the same message kind."""

    def __init__(self,
                 config,
                 request_handler=None,
                 reply_handler=None,
                 event_handler=None,
                 error_to_stderr=False):
        """
        This sniffer will intercept:
         - client requests
         - server replies
         - server events (i.e.: connection state change or, most of the times, watches)
        Hence handlers for each.

        :param config: object providing the sniffer settings (filter, iface,
                       ports, sampling, exclusions, ...)
        :param request_handler: optional callable invoked with each Request
        :param reply_handler: optional callable invoked with each Reply
        :param event_handler: optional callable invoked with each WatchEvent
        :param error_to_stderr: write errors to stderr instead of the log
        """
        super(Sniffer, self).__init__()

        self._error_to_stderr = error_to_stderr
        self._packet_size = 65535
        self._request_handlers = []
        self._reply_handlers = []
        self._event_handlers = []
        self._requests_xids = defaultdict(dict)  # if tracking replies, keep a tab for seen reqs
        self._four_letter_mode = {}  # key: client addr, val: four letter
        self._wants_stop = False

        self.config = config

        self.add_request_handler(request_handler)
        self.add_reply_handler(reply_handler)
        self.add_event_handler(event_handler)

        # NOTE(review): Thread.setDaemon() is deprecated since Python 3.10; if
        # SnifferBase is a threading.Thread, `self.daemon = True` is the
        # replacement -- confirm against SnifferBase before switching.
        self.setDaemon(True)

    def stop(self):
        """Ask the sniff loop to exit; checked through `wants_stop`."""
        self._wants_stop = True

    def add_request_handler(self, handler):
        """Register a callable for client requests (None is ignored)."""
        self._add_handler(self._request_handlers, handler)

    def add_reply_handler(self, handler):
        """Register a callable for server replies (None is ignored)."""
        self._add_handler(self._reply_handlers, handler)

    def add_event_handler(self, handler):
        """Register a callable for server events (None is ignored)."""
        self._add_handler(self._event_handlers, handler)

    def _add_handler(self, handlers, handler):
        """
        Append `handler` to `handlers` unless it is None.

        :raises RegistrationError: if `handler` is already in `handlers`
        """
        if handler is None:
            return

        if handler in handlers:
            raise self.RegistrationError("handler %s has already been added" % (handler))

        handlers.append(handler)

    def wants_stop(self, *args, **kwargs):  # pragma: no cover
        """Stop predicate handed to sniff(); extra arguments are ignored."""
        return self._wants_stop

    def _report_error(self, fmt, *args):
        """Emit an error on stderr or through the logger, per `error_to_stderr`."""
        if self._error_to_stderr:
            sys.stderr.write(fmt % args + "\n")
        else:
            log.error(fmt, *args)

    def run(self):
        """
        Run the capture loop until it exits, then send SIGINT to this process.

        socket.error raised by the capture is reported, not propagated.
        """
        try:
            log.info("Setting filter: %s", self.config.filter)
            sniff_kwargs = {
                "filter": self.config.filter,
                "store": 0,
                "prn": self.handle_packet,
                "stop_filter": self.wants_stop,
            }
            # For "any", no iface is passed to sniff() at all.
            if self.config.iface != "any":
                sniff_kwargs["iface"] = self.config.iface
            sniff(**sniff_kwargs)
        except socket.error as ex:
            self._report_error("Error: %s, device: %s", ex, self.config.iface)
        finally:
            log.info("The sniff loop exited")
            os.kill(os.getpid(), signal.SIGINT)

    def handle_packet(self, packet):
        """
        Decode one captured packet and dispatch the resulting message.

        When `config.sampling` is below 1.0 a packet is skipped whenever a
        fresh random() draw exceeds it.  Decoding errors are swallowed; if
        `config.dump_bad_packet` is set the error and a hexdump are printed.
        """
        sampling = self.config.sampling
        if sampling < 1.0 and random() > sampling:
            return

        try:
            message = self.message_from_packet(packet)
            self.handle_message(message)
        except (BadPacket, StringTooLong, DeserializationError, struct.error) as ex:
            if self.config.dump_bad_packet:
                print("got: %s" % str(ex))
                hexdump.hexdump(packet.load)
                sys.stdout.flush()

    def handle_message(self, message):
        """Call every matching handler, unless the message is empty or its opcode excluded."""
        if message and not self.config.excluded(message.opcode):
            for h in self._handlers_for(message):
                h(message)

    def _handlers_for(self, message):
        """
        :returns: the handler list matching the type of `message`
        :raises:
          :exc:`BadPacket` for a non-write Request when `config.writes_only`
                           is set, or for an unsupported message type
        """
        if isinstance(message, Request):
            if self.config.writes_only and not message.is_write:
                raise BadPacket("Not a write packet")
            return self._request_handlers
        elif isinstance(message, Reply):
            return self._reply_handlers
        elif isinstance(message, WatchEvent):
            return self._event_handlers

        raise BadPacket("No handlers for: %s" % (message))

    def message_from_packet(self, packet):
        """
        :returns: Returns an instance of ClientMessage or ServerMessage (or a subclass),
                  or None when the packet carries no payload
        :raises:
          :exc:`BadPacket` if the packet is for a client we are not tracking
          :exc:`DeserializationError` if deserialization failed
          :exc:`struct.error` if deserialization failed
        """
        client_port = self.config.client_port
        zk_port = self.config.zookeeper_port
        ip_p = get_ip_packet(packet.load, client_port, zk_port, self.config.is_loopback)

        segment = ip_p.data
        data = segment.data
        if 0 == len(data):
            return None

        # Destination port is checked first, so it wins if both ports match.
        to_server = segment.dport == zk_port
        if not to_server and segment.sport != zk_port:
            raise BadPacket("Packet to the wrong port?")

        src = intern("%s:%s" % (get_ip(ip_p, ip_p.src), segment.sport))
        dst = intern("%s:%s" % (get_ip(ip_p, ip_p.dst), segment.dport))

        if to_server:
            client, server = src, dst
            if data.startswith(FOUR_LETTER_WORDS):
                # Remember the command so the next packet back to this client
                # is recognised as its response.
                self._set_four_letter_mode(client, data[0:4])
                raise BadPacket("Four letter request %s" % data[0:4])
            client_message = ClientMessage.from_payload(data, client, server)
            client_message.timestamp = packet.time
            self._track_client_message(client_message)
            return client_message

        client, server = dst, src
        four_letter = self._get_four_letter_mode(client)
        if four_letter:
            self._set_four_letter_mode(client, None)
            raise BadPacket("Four letter response %s" % four_letter)
        requests_xids = self._requests_xids.get(client, {})
        server_message = ServerMessage.from_payload(data, client, server, requests_xids)
        server_message.timestamp = packet.time
        return server_message

    def _track_client_message(self, request):
        """
        Any request that is not a ping or a close should be tracked

        Tracking only happens when `config.track_replies` is set; once a
        client has more than `config.max_queued_requests` pending entries the
        request is reported and not recorded.
        """
        if self.config.track_replies and not request.is_ping and not request.is_close:
            requests_xids = self._requests_xids[request.client]
            if len(requests_xids) > self.config.max_queued_requests:  # pragma: no cover
                # TODO: logging the counts of each type of pkts in the queue when this happens
                #       could be useful.
                self._report_error(
                    "Too many queued requests, replies for %s will be lost", request.client)
                return
            requests_xids[request.xid] = request.opcode

    def _get_four_letter_mode(self, client):
        """:returns: the four-letter command recorded for `client`, or None"""
        return self._four_letter_mode.get(client)

    def _set_four_letter_mode(self, client, four_letter):
        """
        Record `four_letter` for `client`; a falsy value clears the entry.

        Clearing a client that has no entry is a no-op.
        """
        if four_letter:
            self._four_letter_mode[client] = four_letter
        else:
            self._four_letter_mode.pop(client, None)