root/masterdriverz/snakeoil/snakeoil/containers.py @ masterdriverz%2540gentoo.org-20070622063143-0nuxfo9iay6az0bv

Revision masterdriverz%2540gentoo.org-20070622063143-0nuxfo9iay6az0bv, 5.5 kB (checked in by Charlie Shepherd <masterdriverz@…>, 19 months ago)

Condense the set methods into single lines

Line 
1# Copyright: 2005-2007 Brian Harring <ferringb@gmail.com>
2# License: GPL2
3
4"""
5collection of container classes
6"""
7
8from snakeoil.demandload import demandload
9demandload(
10    globals(),
11    'itertools:chain,ifilterfalse',
12)
13
class InvertedContains(set):

    """Set whose membership test returns the opposite of a plain set's.

    A key stored in the set is reported as absent, and every key that is
    not stored is reported as present.  Mainly useful in conjunction with
    LimitedChangeSet for converting from blacklist to whitelist.

    Cannot be iterated over.
    """

    def __iter__(self):
        # Conceptually this holds "everything except the stored keys",
        # which is an infinite set, so iteration is refused outright.
        raise TypeError("InvertedContains cannot be iterated over")

    def __contains__(self, key):
        # Ask the real set, then flip the answer.
        if set.__contains__(self, key):
            return False
        return True
30
31
class SetMixin(object):
    """
    Mixin implementing the binary set operators on top of iteration.

    Subclasses should provide __init__, __iter__ and __contains__; the
    constructor is called with a single iterable of the result's items.

    Results are not de-duplicated here: the constructor receiving the
    items is expected to handle that, so doing it twice would be wasted
    work.  The reflected operators (__rand__, __ror__, ...) build their
    result with the class of the left-hand operand.
    """

    def __and__(self, other, kls=None):
        """Return the items of self that are also in other.

        kls, when given, is the class used to build the result instead
        of self's own class.
        """
        factory = kls or self.__class__
        common = (x for x in self if x in other)
        return factory(common)

    def __rand__(self, other):
        """Reflected intersection; the result has other's class."""
        return self.__and__(other, kls=other.__class__)

    def __or__(self, other, kls=None):
        """Return the items of self followed by the items of other."""
        factory = kls or self.__class__
        return factory(chain(self, other))

    def __ror__(self, other):
        """Reflected union; the result has other's class."""
        return self.__or__(other, kls=other.__class__)

    def __xor__(self, other, kls=None):
        """Return the items that are in exactly one of self and other."""
        factory = kls or self.__class__
        only_here = (x for x in self if x not in other)
        only_there = (x for x in other if x not in self)
        return factory(chain(only_here, only_there))

    def __rxor__(self, other):
        """Reflected symmetric difference; the result has other's class."""
        return self.__xor__(other, kls=other.__class__)

    def __sub__(self, other):
        """Return the items of self that are not in other."""
        remaining = (x for x in self if x not in other)
        return self.__class__(remaining)

    def __rsub__(self, other):
        """Return the items of other that are not in self, as other's class."""
        remaining = (x for x in other if x not in self)
        return other.__class__(remaining)

    # "+" is treated as a synonym for "|".
    __add__ = __or__
    __radd__ = __ror__
69
70
class LimitedChangeSet(SetMixin):

    """Set used to limit the number of times a key can be removed/added.

    Each key may be changed (added or removed) only once per commit, and
    changes to a given collection of keys can optionally be blocked
    altogether.  Changes are recorded in order so that they can be
    rolled back to any earlier point since the last commit.
    """

    # Tags stored in the change log to record the kind of each change.
    _removed    = 0
    _added      = 1

    def __init__(self, initial_keys, unchangable_keys=None):
        """
        @param initial_keys: iterable of the keys the set starts with.
        @param unchangable_keys: optional container of keys that may
            never be changed; a list or tuple is converted to a set,
            anything else is used as is for membership tests.
        """
        self._new = set(initial_keys)
        if unchangable_keys is None:
            self._blacklist = []
        else:
            if isinstance(unchangable_keys, (list, tuple)):
                unchangable_keys = set(unchangable_keys)
            self._blacklist = unchangable_keys
        # keys touched since the last commit
        self._changed = set()
        # ordered log of (change kind, key) since the last commit
        self._change_order = []
        # contents as of the last commit; rollback restores from this
        self._orig = frozenset(self._new)

    def add(self, key):
        """Add key to the set, recording the change.

        If the key was already changed since the last commit, or is
        blacklisted, this is a silent no-op when the key is present and
        raises Unchangable when it is not.
        """
        if key in self._changed or key in self._blacklist:
            # it's been del'd already once upon a time.
            if key in self._new:
                return
            raise Unchangable(key)

        self._new.add(key)
        self._changed.add(key)
        self._change_order.append((self._added, key))

    def remove(self, key):
        """Remove key from the set, recording the change.

        If the key was already changed since the last commit, or is
        blacklisted, raises KeyError when the key is absent and
        Unchangable when it is present.  Otherwise the key is dropped if
        present; the removal is recorded either way, so a later add() of
        the same key before the next commit raises Unchangable.
        """
        if key in self._changed or key in self._blacklist:
            if key not in self._new:
                raise KeyError(key)
            raise Unchangable(key)

        if key in self._new:
            self._new.remove(key)
        self._changed.add(key)
        self._change_order.append((self._removed, key))

    def __contains__(self, key):
        return key in self._new

    def changes_count(self):
        """Return the number of changes recorded since the last commit."""
        return len(self._change_order)

    def commit(self):
        """Make the current contents the new baseline and clear the log."""
        self._orig = frozenset(self._new)
        self._changed.clear()
        self._change_order = []

    def rollback(self, point=0):
        """Undo recorded changes until only the first `point` remain.

        @param point: number of changes to keep; 0 undoes everything
            since the last commit.
        @raise TypeError: if point is negative or larger than
            changes_count().
        """
        l = self.changes_count()
        if point < 0 or point > l:
            raise TypeError(
                "%s point must be >=0 and <= changes_count()" % point)
        while l > point:
            key = self._change_order.pop(-1)[1]
            self._changed.remove(key)
            # A key can be changed at most once per commit, so its state
            # before the change is exactly its membership in _orig.
            # Blindly inverting the recorded operation would be wrong
            # for no-op changes: add() of a key that was already present
            # would be "undone" by removing it, and remove() of an
            # absent key by adding it.
            if key in self._orig:
                self._new.add(key)
            else:
                self._new.discard(key)
            l -= 1

    def __str__(self):
        return str(self._new).replace("set(", "LimitedChangeSet(", 1)

    def __iter__(self):
        return iter(self._new)

    def __len__(self):
        return len(self._new)

    def __eq__(self, other):
        # Only the current contents are compared; the change log and the
        # blacklist play no part in equality.
        if isinstance(other, LimitedChangeSet):
            return self._new == other._new
        elif isinstance(other, (frozenset, set)):
            return self._new == other
        return False

    def __ne__(self, other):
        return not (self == other)
159
160
class Unchangable(Exception):

    """Raised when a change is attempted on a key that may not change.

    The offending key is available as the `key` attribute.
    """

    def __init__(self, key):
        self.key = key
        message = "key '%s' is unchangable" % (key,)
        Exception.__init__(self, message)
166
167
class ProtectedSet(SetMixin):

    """
    Wraps a set pushing all changes into a secondary set.

    The wrapped set is only ever read; keys added through this object
    are stored in a separate set owned by the wrapper.
    """

    def __init__(self, orig_set):
        # local additions live here; orig_set itself is never modified
        self._new = set()
        self._orig = orig_set

    def add(self, key):
        # keys already provided by the wrapped set are not duplicated
        if key in self._orig:
            return
        self._new.add(key)

    def __contains__(self, key):
        if key in self._orig:
            return True
        return key in self._new

    def __len__(self):
        combined = self._orig.union(self._new)
        return len(combined)

    def __iter__(self):
        # local additions first, then the wrapped set's items that are
        # not among the local additions
        additions = self._new
        fresh = iter(additions)
        inherited = ifilterfalse(additions.__contains__, self._orig)
        return chain(fresh, inherited)
190
191
class RefCountingSet(dict):

    """Set that keeps a reference count for every item.

    Implemented as a dict mapping each item to the number of times it
    has been added and not yet removed.  An item stays in the set until
    it has been removed as many times as it was added.
    """

    def __init__(self, iterable=None):
        """
        @param iterable: optional iterable of initial items; every
            occurrence of an item counts as one reference, exactly as
            if add() had been called for it.
        """
        dict.__init__(self)
        if iterable is not None:
            # Go through add() so repeated items are counted per
            # occurrence rather than collapsing to a count of 1.
            for item in iterable:
                self.add(item)

    def add(self, item):
        """Add one reference to item, inserting it if it is not present."""
        self[item] = self.get(item, 0) + 1

    def remove(self, item):
        """Drop one reference to item, deleting it when none remain.

        @raise KeyError: if item is not in the set.
        """
        count = self[item]
        if count == 1:
            del self[item]
        else:
            self[item] = count - 1
Note: See TracBrowser for help on using the browser.