root/masterdriverz/snakeoil-formatters/snakeoil/containers.py @ masterdriverz%2540gentoo.org-20070511133515-3shy3wccub8fxul2

Revision masterdriverz%2540gentoo.org-20070511133515-3shy3wccub8fxul2, 5.6 kB (checked in by Charlie Shepherd <masterdriverz@…>, 20 months ago)

Merge from ferringb

Line 
1# Copyright: 2005-2007 Brian Harring <ferringb@gmail.com>
2# License: GPL2
3
4"""
5collection of container classes
6"""
7
8from snakeoil.demandload import demandload
9demandload(
10    globals(),
11    'itertools:chain,ifilterfalse',
12)
13
class InvertedContains(set):

    """Set whose membership test answers the opposite of a normal set.

    Mainly useful in conjunction with LimitedChangeSet, to turn a
    blacklist into a whitelist.

    Conceptually it holds everything *except* the stored keys, so it
    cannot be iterated over.
    """

    def __iter__(self):
        # the inverted view is infinite; there is nothing sane to yield.
        raise TypeError("InvertedContains cannot be iterated over")

    def __contains__(self, key):
        stored = super(InvertedContains, self).__contains__(key)
        return not stored
30
31
class SetMixin(object):
    """
    Mixin supplying the binary set operators.

    A subclass has to define __init__ (taking an iterable of items),
    __iter__ and __contains__; every operator here is built on those.
    The reflected variants build an instance of the *other* operand's
    class.
    """

    # None of the operators strips duplicates from the stream handed to
    # the result class: that class's __init__ is expected to deal with
    # repeats anyway, so filtering here would just be wasted work.

    def __sub__(self, other):
        remaining = (item for item in self if item not in other)
        return self.__class__(remaining)

    def __rsub__(self, other):
        remaining = (item for item in other if item not in self)
        return other.__class__(remaining)

    def __and__(self, other, kls=None):
        if not kls:
            kls = self.__class__
        shared = (item for item in self if item in other)
        return kls(shared)

    def __rand__(self, other):
        return self.__and__(other, kls=other.__class__)

    def __or__(self, other, kls=None):
        if not kls:
            kls = self.__class__
        combined = chain(self, other)
        return kls(combined)

    def __ror__(self, other):
        return self.__or__(other, kls=other.__class__)

    def __xor__(self, other, kls=None):
        if not kls:
            kls = self.__class__
        only_ours = (item for item in self if item not in other)
        only_theirs = (item for item in other if item not in self)
        return kls(chain(only_ours, only_theirs))

    def __rxor__(self, other):
        return self.__xor__(other, kls=other.__class__)

    # '+' is treated as a synonym for union.
    __add__ = __or__
    __radd__ = __ror__
72
73
class LimitedChangeSet(SetMixin):

    """Set used to limit the number of times a key can be removed/added.

    Specifically, a key can be added or removed only once per commit;
    changes to certain keys can optionally be blocked altogether.

    Every change is journalled, so it can be undone with rollback() until
    commit() makes the current contents the new baseline.
    """

    # markers stored alongside the key in the change journal.
    _removed    = 0
    _added      = 1

    def __init__(self, initial_keys, unchangable_keys=None):
        """
        initial_keys is an iterable of the keys the set starts out with.

        unchangable_keys, if given, is a container of keys that may never
        be added or removed.  Lists and tuples are converted to a set;
        any other object is used as is for containment checks.
        """
        self._new = set(initial_keys)
        if unchangable_keys is None:
            self._blacklist = []
        else:
            if isinstance(unchangable_keys, (list, tuple)):
                unchangable_keys = set(unchangable_keys)
            self._blacklist = unchangable_keys
        # keys touched since the last commit.
        self._changed = set()
        # journal of (marker, key) pairs, oldest change first.
        self._change_order = []
        # contents as of the last commit; rollback restores from this.
        self._orig = frozenset(self._new)

    def add(self, key):
        """
        Add key to the set, using up its one change for this commit.

        If the key is locked (already changed since the last commit, or
        blacklisted) this is a silent no-op when the key is present, and
        raises Unchangable when it is not.
        """
        if key in self._changed or key in self._blacklist:
            # it's been del'd already once upon a time.
            if key in self._new:
                return
            raise Unchangable(key)

        self._new.add(key)
        self._changed.add(key)
        self._change_order.append((self._added, key))

    def remove(self, key):
        """
        Remove key from the set, using up its one change for this commit.

        If the key is locked (already changed since the last commit, or
        blacklisted) this raises KeyError when the key is absent and
        Unchangable when it is present.  Removing an absent, unlocked key
        does not raise, but is still recorded as a change.
        """
        if key in self._changed or key in self._blacklist:
            if key not in self._new:
                raise KeyError(key)
            raise Unchangable(key)

        if key in self._new:
            self._new.remove(key)
        self._changed.add(key)
        self._change_order.append((self._removed, key))

    def __contains__(self, key):
        return key in self._new

    def changes_count(self):
        """Return the number of changes recorded since the last commit."""
        return len(self._change_order)

    def commit(self):
        """Make the current contents the baseline and forget the journal."""
        self._orig = frozenset(self._new)
        self._changed.clear()
        self._change_order = []

    def rollback(self, point=0):
        """
        Undo recorded changes, newest first, until only point remain.

        point must lie between 0 and changes_count() inclusive, else
        TypeError is raised.  The default of 0 undoes everything since
        the last commit.
        """
        l = self.changes_count()
        if point < 0 or point > l:
            raise TypeError(
                "%s point must be >=0 and <= changes_count()" % point)
        while l > point:
            key = self._change_order.pop(-1)[1]
            self._changed.remove(key)
            # A key changes at most once between commits, so its state
            # before the change is its state in the committed baseline.
            # Restoring from the baseline (rather than blindly inverting
            # the journalled operation) keeps no-op changes -- adding a
            # key that was already present, removing one that was absent
            # -- from corrupting the set when they are rolled back.
            if key in self._orig:
                self._new.add(key)
            else:
                self._new.discard(key)
            l -= 1

    def __str__(self):
        # Built explicitly instead of patching the repr of the underlying
        # set, whose "set(" prefix only exists on python 2; the resulting
        # text is the same there.
        return "LimitedChangeSet(%s)" % (list(self._new),)

    def __iter__(self):
        return iter(self._new)

    def __len__(self):
        return len(self._new)

    def __eq__(self, other):
        # only the current contents are compared, never the journal.
        if isinstance(other, LimitedChangeSet):
            return self._new == other._new
        elif isinstance(other, (frozenset, set)):
            return self._new == other
        return False

    def __ne__(self, other):
        return not (self == other)
162
163
class Unchangable(Exception):

    """Error carrying the key that could not be changed.

    The offending key is available as the 'key' attribute.
    """

    def __init__(self, key):
        message = "key '%s' is unchangable" % (key,)
        super(Unchangable, self).__init__(message)
        self.key = key
169
170
class ProtectedSet(SetMixin):

    """
    Wrapper around a set that never modifies it: keys added through the
    wrapper are collected in a separate, secondary set instead.
    """

    def __init__(self, orig_set):
        self._new = set()
        self._orig = orig_set

    def add(self, key):
        if key in self._orig:
            # the wrapped set already provides it; nothing to record.
            return
        self._new.add(key)

    def __contains__(self, key):
        if key in self._orig:
            return True
        return key in self._new

    def __len__(self):
        merged = self._orig.union(self._new)
        return len(merged)

    def __iter__(self):
        # secondary keys first, then whatever the wrapped set holds that
        # the secondary set doesn't.
        additions = iter(self._new)
        inherited = ifilterfalse(self._new.__contains__, self._orig)
        return chain(additions, inherited)
193
194
class RefCountingSet(dict):

    """Set-like dict mapping each item to the number of times it was added.

    An item stays in the container until it has been removed as many
    times as it was added.
    """

    def __init__(self, iterable=None):
        """
        iterable, if given, supplies the initial items.  An item that
        occurs several times is counted once per occurrence, exactly as
        if add() had been called for each of them.
        """
        dict.__init__(self)
        if iterable is not None:
            # go through add() so duplicates raise the count rather than
            # collapsing into a single reference.
            for item in iterable:
                self.add(item)

    def add(self, item):
        """Add one reference to item, inserting it if it is new."""
        self[item] = self.get(item, 0) + 1

    def remove(self, item):
        """
        Drop one reference to item, deleting it once none are left.

        Raises KeyError if item is not in the container.
        """
        count = self[item]
        if count == 1:
            del self[item]
        else:
            self[item] = count - 1
Note: See TracBrowser for help on using the browser.