summaryrefslogtreecommitdiffstats
path: root/tools/multi_process_rss.py
blob: 100d0f759b1c4660de6904b16899aeab30bd00fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python
# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

# Counts the combined resident set size (RSS) of multiple processes without
# double-counting: a page frame shared by several processes is counted once.
#
# Usage:
# ./multi_process_rss.py <pid>|<pid>r [...]
#
# If <pid> ends with 'r', all descendants of that process are counted too.
#
# Example:
# ./multi_process_rss.py 12345 23456r
#
# The command line above counts the combined RSS of 1) process 12345,
# 2) process 23456 and 3) all descendant processes of process 23456.


import collections
import logging
import os
import psutil
import sys


# procfs is a helper module in the 'linux' subdirectory next to this script.
# It is only imported on Linux, so code that uses it (count_pageframes,
# count_statm) is only usable there; main() checks the platform first.
if sys.platform.startswith('linux'):
  _TOOLS_PATH = os.path.dirname(os.path.abspath(__file__))
  _TOOLS_LINUX_PATH = os.path.join(_TOOLS_PATH, 'linux')
  # The directory must be on sys.path before the import below.
  sys.path.append(_TOOLS_LINUX_PATH)
  import procfs  # pylint: disable=F0401


class _NullHandler(logging.Handler):
  """Logging handler that drops every record it is given."""

  def emit(self, record):
    """Discard |record| without formatting or writing it anywhere."""
    return None


# Module logger.  A do-nothing handler is attached so the logger always has
# at least one handler; main() adds a real StreamHandler at WARNING level.
_LOGGER = logging.getLogger('multi-process-rss')
_LOGGER.addHandler(_NullHandler())


def _recursive_get_children(pid):
  """Return the PIDs of all descendants of process |pid|, depth-first.

  Each child is listed before its own descendants (pre-order).  A process
  that no longer exists contributes no entries, so processes that exit while
  the tree is being walked are silently skipped.

  Works with both the old psutil API (Process.get_children) and newer
  releases, which renamed it to Process.children() and dropped the
  psutil.error module; psutil.NoSuchProcess is exported at top level in both.

  Args:
    pid: An integer process ID.

  Returns:
    A list of integer PIDs; empty if |pid| does not exist.
  """
  try:
    process = psutil.Process(pid)
    # Prefer the modern name; fall back for psutil releases that predate it.
    get_children = getattr(process, 'children', None) or process.get_children
    children = get_children()
  except psutil.NoSuchProcess:
    return []
  descendants = []
  for child in children:
    descendants.append(child.pid)
    descendants.extend(_recursive_get_children(child.pid))
  return descendants


def list_pids(argv):
  """Expand command-line PID arguments into an ordered, duplicate-free list.

  Args:
    argv: The full argument vector; argv[0] (the program name) is skipped.
        Every other element is '<pid>' or '<pid>r'.  The 'r' suffix also
        selects every descendant of <pid>.

  Returns:
    A list of integer PIDs in first-seen order, with duplicates removed.

  Raises:
    SyntaxError: If an argument, minus any 'r' suffix, is not an integer.
  """
  pids = []
  for arg in argv[1:]:
    recursive = arg.endswith('r')
    try:
      pid = int(arg[:-1] if recursive else arg)
    except ValueError:
      raise SyntaxError("%s is not an integer." % arg)
    pids.append(pid)
    if recursive:
      pids.extend(_recursive_get_children(pid))

  # Order-preserving de-duplication in linear time: keep each PID's first
  # occurrence (the same result as sorting the set by first index).
  seen = set()
  unique_pids = []
  for pid in pids:
    if pid not in seen:
      seen.add(pid)
      unique_pids.append(pid)
  return unique_pids


def count_pageframes(pids):
  """Count how often each page frame is mapped across the processes |pids|.

  Args:
    pids: Iterable of integer process IDs.  A PID that appears more than once
        is counted only once.

  Returns:
    A collections.defaultdict(int) mapping each page frame to its count summed
    over every VMA of every readable process.  len() of the result is the
    number of distinct page frames, so frames shared between processes are
    counted once.  Processes whose /proc/<pid>/maps or /proc/<pid>/pagemap
    cannot be loaded are logged and skipped.
  """
  pageframes = collections.defaultdict(int)
  counted = set()
  for pid in pids:
    if pid in counted:
      continue
    maps = procfs.ProcMaps.load(pid)
    if not maps:
      _LOGGER.warning('/proc/%d/maps not found.', pid)
      continue
    pagemap = procfs.ProcPagemap.load(pid, maps)
    if not pagemap:
      _LOGGER.warning('/proc/%d/pagemap not found.', pid)
      continue
    # Fold this process in right away instead of keeping every pagemap alive
    # until the end, so only one process's pagemap is held at a time.
    for vma in pagemap.vma_internals.itervalues():
      for pageframe, number in vma.pageframes.iteritems():
        pageframes[pageframe] += number
    counted.add(pid)

  return pageframes


def count_statm(pids):
  """Sum the statm 'resident' and 'share' figures over the processes |pids|.

  Units are whatever procfs.ProcStatm reports (presumably pages, as in
  /proc/<pid>/statm; confirm against procfs).  Unlike count_pageframes, pages
  shared between processes are counted once per process here.

  Args:
    pids: Iterable of integer process IDs.

  Returns:
    A (resident, shared, private) tuple, where private accumulates
    resident - share per process.  Processes whose statm cannot be loaded
    are logged and skipped.
  """
  total_resident = total_shared = total_private = 0

  for pid in pids:
    statm = procfs.ProcStatm.load(pid)
    if statm:
      total_resident += statm.resident
      total_shared += statm.share
      total_private += (statm.resident - statm.share)
    else:
      _LOGGER.warning('/proc/%d/statm not found.' % pid)

  return total_resident, total_shared, total_private


def main(argv):
  """Print the combined, de-duplicated RSS in bytes of the given processes.

  Args:
    argv: The full argument vector; argv[1:] holds '<pid>' or '<pid>r'
        entries, where the 'r' suffix also selects all descendants.

  Returns:
    0 on success.  1 if the platform is not Linux or an argument is not a
    valid PID.
  """
  logging_handler = logging.StreamHandler()
  logging_handler.setLevel(logging.WARNING)
  logging_handler.setFormatter(logging.Formatter(
      '%(asctime)s:%(name)s:%(levelname)s:%(message)s'))

  _LOGGER.setLevel(logging.WARNING)
  _LOGGER.addHandler(logging_handler)

  # procfs is only imported on Linux, so bail out everywhere else.
  if not sys.platform.startswith('linux'):
    _LOGGER.error('%s is not supported.' % sys.platform)
    return 1

  logging.getLogger('procfs').setLevel(logging.WARNING)
  logging.getLogger('procfs').addHandler(logging_handler)

  try:
    pids = list_pids(argv)
  except SyntaxError as e:
    # Report a malformed argument cleanly instead of dumping a traceback.
    _LOGGER.error('%s', e)
    return 1
  pageframes = count_pageframes(pids)

  # len(pageframes) is the number of distinct page frames.  Convert with the
  # system page size rather than assuming 4 KiB, because some kernels (e.g.
  # some arm64 configurations) use 16 KiB or 64 KiB pages.
  page_size = os.sysconf('SC_PAGE_SIZE')

  # TODO(dmikurube): Classify this total RSS.
  print(len(pageframes) * page_size)

  return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
  sys.exit(main(sys.argv))