aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Harring <ferringb@gmail.com>2022-12-24 13:14:53 -0800
committerArthur Zamarin <arthurzam@gentoo.org>2022-12-25 19:49:11 +0200
commitd6a7c2e44b4f497357f8569d423104232a58f384 (patch)
tree625ac52169356714a9f5e69e11f2b6cc2d72355a
parentcompression: prefer gtar over tar if available (diff)
downloadsnakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.tar.gz
snakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.tar.bz2
snakeoil-d6a7c2e44b4f497357f8569d423104232a58f384.zip
Reformat w/ black 22.12.0 for consistency.
Signed-off-by: Brian Harring <ferringb@gmail.com> Signed-off-by: Arthur Zamarin <arthurzam@gentoo.org>
-rw-r--r--doc/conf.py157
-rw-r--r--src/snakeoil/__init__.py4
-rw-r--r--src/snakeoil/_fileutils.py25
-rw-r--r--src/snakeoil/bash.py115
-rw-r--r--src/snakeoil/caching.py10
-rw-r--r--src/snakeoil/chksum/__init__.py19
-rw-r--r--src/snakeoil/chksum/defaults.py53
-rw-r--r--src/snakeoil/cli/arghparse.py481
-rw-r--r--src/snakeoil/cli/exceptions.py6
-rw-r--r--src/snakeoil/cli/input.py37
-rw-r--r--src/snakeoil/cli/tool.py63
-rw-r--r--src/snakeoil/compatibility.py15
-rw-r--r--src/snakeoil/compression/__init__.py107
-rw-r--r--src/snakeoil/compression/_bzip2.py30
-rw-r--r--src/snakeoil/compression/_util.py59
-rw-r--r--src/snakeoil/compression/_xz.py28
-rw-r--r--src/snakeoil/constraints.py30
-rw-r--r--src/snakeoil/containers.py26
-rw-r--r--src/snakeoil/contexts.py78
-rw-r--r--src/snakeoil/currying.py60
-rw-r--r--src/snakeoil/data_source.py81
-rw-r--r--src/snakeoil/decorators.py7
-rw-r--r--src/snakeoil/demandimport.py30
-rw-r--r--src/snakeoil/demandload.py95
-rw-r--r--src/snakeoil/dependant_methods.py20
-rw-r--r--src/snakeoil/errors.py14
-rw-r--r--src/snakeoil/fileutils.py43
-rw-r--r--src/snakeoil/formatters.py115
-rw-r--r--src/snakeoil/iterables.py12
-rw-r--r--src/snakeoil/klass.py164
-rw-r--r--src/snakeoil/mappings.py67
-rw-r--r--src/snakeoil/modules.py1
-rw-r--r--src/snakeoil/obj.py169
-rw-r--r--src/snakeoil/osutils/__init__.py40
-rw-r--r--src/snakeoil/osutils/mount.py10
-rw-r--r--src/snakeoil/osutils/native_readdir.py55
-rw-r--r--src/snakeoil/pickling.py3
-rw-r--r--src/snakeoil/process/__init__.py8
-rw-r--r--src/snakeoil/process/namespaces.py47
-rw-r--r--src/snakeoil/process/spawn.py95
-rw-r--r--src/snakeoil/sequences.py35
-rw-r--r--src/snakeoil/stringio.py2
-rw-r--r--src/snakeoil/strings.py14
-rw-r--r--src/snakeoil/tar.py27
-rw-r--r--src/snakeoil/test/__init__.py29
-rw-r--r--src/snakeoil/test/argparse_helpers.py57
-rw-r--r--src/snakeoil/test/eq_hash_inheritance.py11
-rw-r--r--src/snakeoil/test/mixins.py46
-rw-r--r--src/snakeoil/test/modules.py4
-rw-r--r--src/snakeoil/test/slot_shadowing.py22
-rw-r--r--src/snakeoil/version.py50
-rw-r--r--src/snakeoil/weakrefs.py1
-rw-r--r--tests/cli/test_arghparse.py336
-rw-r--r--tests/cli/test_input.py99
-rw-r--r--tests/compression/__init__.py68
-rw-r--r--tests/compression/test_bzip2.py40
-rw-r--r--tests/compression/test_init.py77
-rw-r--r--tests/compression/test_xz.py30
-rw-r--r--tests/test_bash.py283
-rw-r--r--tests/test_caching.py22
-rw-r--r--tests/test_chksum.py4
-rw-r--r--tests/test_chksum_defaults.py37
-rw-r--r--tests/test_constraints.py64
-rw-r--r--tests/test_containers.py19
-rw-r--r--tests/test_contexts.py51
-rw-r--r--tests/test_currying.py119
-rw-r--r--tests/test_data_source.py37
-rw-r--r--tests/test_decorators.py37
-rw-r--r--tests/test_demandload.py86
-rw-r--r--tests/test_demandload_usage.py6
-rw-r--r--tests/test_dependant_methods.py21
-rw-r--r--tests/test_fileutils.py105
-rw-r--r--tests/test_formatters.py243
-rw-r--r--tests/test_iterables.py11
-rw-r--r--tests/test_klass.py154
-rw-r--r--tests/test_mappings.py208
-rw-r--r--tests/test_modules.py67
-rw-r--r--tests/test_obj.py85
-rw-r--r--tests/test_osutils.py221
-rw-r--r--tests/test_process.py6
-rw-r--r--tests/test_process_spawn.py50
-rw-r--r--tests/test_sequences.py110
-rw-r--r--tests/test_stringio.py5
-rw-r--r--tests/test_strings.py34
-rw-r--r--tests/test_version.py139
85 files changed, 3268 insertions, 2383 deletions
diff --git a/doc/conf.py b/doc/conf.py
index d2cbc09d..0a83d80a 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -18,231 +18,248 @@ import sys
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../src/'))
+sys.path.insert(0, os.path.abspath("../src/"))
# generate API docs
-subprocess.call([
- 'sphinx-apidoc', '-ef', '-o', 'api', '../src/snakeoil',
- '../src/snakeoil/dist', '../src/snakeoil/test', # excludes
-])
+subprocess.call(
+ [
+ "sphinx-apidoc",
+ "-ef",
+ "-o",
+ "api",
+ "../src/snakeoil",
+ "../src/snakeoil/dist",
+ "../src/snakeoil/test", # excludes
+ ]
+)
# -- General configuration -----------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '1.0'
+needs_sphinx = "1.0"
# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
- 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.doctest',
- 'sphinx.ext.extlinks',
- 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage',
- 'sphinx.ext.ifconfig', 'sphinx.ext.graphviz',
- 'sphinx.ext.viewcode',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.doctest",
+ "sphinx.ext.extlinks",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.todo",
+ "sphinx.ext.coverage",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.graphviz",
+ "sphinx.ext.viewcode",
]
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix of source filenames.
-source_suffix = '.rst'
+source_suffix = ".rst"
# The encoding of source files.
-#source_encoding = 'utf-8-sig'
+# source_encoding = 'utf-8-sig'
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# General information about the project.
-project = 'snakeoil'
-authors = ''
-copyright = '2007-2022, snakeoil contributors'
+project = "snakeoil"
+authors = ""
+copyright = "2007-2022, snakeoil contributors"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
-version = '-trunk'
+version = "-trunk"
# The full version, including alpha/beta/rc tags.
-release = '-trunk'
+release = "-trunk"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
-#language = None
+# language = None
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
-#today = ''
+# today = ''
# Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
+# today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
-exclude_patterns = ['_build']
+exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all documents.
-#default_role = None
+# default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
+# add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
-#add_module_names = True
+# add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
+# modindex_common_prefix = []
# -- Options for HTML output ---------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-#html_theme = 'default'
-html_theme = 'default'
+# html_theme = 'default'
+html_theme = "default"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
-#html_theme_options = {}
+# html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
-#html_theme_path = []
+# html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
-#html_title = None
+# html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
+# html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
-#html_logo = None
+# html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
-#html_favicon = None
+# html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-#html_static_path = ['_static']
+# html_static_path = ['_static']
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
+# html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
-#html_use_smartypants = True
+# html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
+# html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
-#html_additional_pages = {}
+# html_additional_pages = {}
# If false, no module index is generated.
-#html_domain_indices = True
+# html_domain_indices = True
# If false, no index is generated.
-#html_use_index = True
+# html_use_index = True
# If true, the index is split into individual pages for each letter.
-#html_split_index = False
+# html_split_index = False
# If true, links to the reST sources are added to the pages.
html_show_sourcelink = False
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = True
+# html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = True
+# html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
-#html_use_opensearch = ''
+# html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
+# html_file_suffix = None
# Output file base name for HTML help builder.
-htmlhelp_basename = 'snakeoildoc'
+htmlhelp_basename = "snakeoildoc"
# -- Options for LaTeX output --------------------------------------------------
# The paper size ('letter' or 'a4').
-#latex_paper_size = 'letter'
+# latex_paper_size = 'letter'
# The font size ('10pt', '11pt' or '12pt').
-#latex_font_size = '10pt'
+# latex_font_size = '10pt'
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
- ('index', 'snakeoil.tex', 'snakeoil Documentation',
- authors, 'manual'),
+ ("index", "snakeoil.tex", "snakeoil Documentation", authors, "manual"),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
-#latex_logo = None
+# latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
-#latex_use_parts = False
+# latex_use_parts = False
# If true, show page references after internal links.
-#latex_show_pagerefs = False
+# latex_show_pagerefs = False
# If true, show URL addresses after external links.
-#latex_show_urls = False
+# latex_show_urls = False
# Additional stuff for the LaTeX preamble.
-#latex_preamble = ''
+# latex_preamble = ''
# Documents to append as an appendix to all manuals.
-#latex_appendices = []
+# latex_appendices = []
# If false, no module index is generated.
-#latex_domain_indices = True
+# latex_domain_indices = True
# -- Options for manual page output --------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [
- ('index', 'snakeoil', 'snakeoil Documentation', [], 1)
-]
+man_pages = [("index", "snakeoil", "snakeoil Documentation", [], 1)]
# Example configuration for intersphinx: refer to the Python standard library.
-intersphinx_mapping = {'http://docs.python.org/': None}
-autodoc_default_flags = ['members', 'show-inheritance', 'inherited-members'] # + ['undoc-members']
+intersphinx_mapping = {"http://docs.python.org/": None}
+autodoc_default_flags = [
+ "members",
+ "show-inheritance",
+ "inherited-members",
+] # + ['undoc-members']
autosummary_generate = False
rst_epilog = """
.. |homepage| replace:: https://github.com/pkgcore/snakeoil
.. |release_url| replace:: https://github.com/pkgcore/snakeoil/releases
-""" % {"release": release}
+""" % {
+ "release": release
+}
extlinks = {
- 'git_tag': ('https://github.com/pkgcore/snakeoil/releases/tag/%s', 'git log '),
- 'git_release': ('https://github.com/pkgcore/snakeoil/archive/%s.tar.gz',
- 'release download '),
+ "git_tag": ("https://github.com/pkgcore/snakeoil/releases/tag/%s", "git log "),
+ "git_release": (
+ "https://github.com/pkgcore/snakeoil/archive/%s.tar.gz",
+ "release download ",
+ ),
}
diff --git a/src/snakeoil/__init__.py b/src/snakeoil/__init__.py
index 636d6636..a1c1f205 100644
--- a/src/snakeoil/__init__.py
+++ b/src/snakeoil/__init__.py
@@ -10,5 +10,5 @@ This library is a bit of a grabbag of the following:
* optimized implementations of common patterns
"""
-__title__ = 'snakeoil'
-__version__ = '0.10.4'
+__title__ = "snakeoil"
+__version__ = "0.10.4"
diff --git a/src/snakeoil/_fileutils.py b/src/snakeoil/_fileutils.py
index 51f82c76..4a226486 100644
--- a/src/snakeoil/_fileutils.py
+++ b/src/snakeoil/_fileutils.py
@@ -5,7 +5,9 @@ Access this functionality from :py:module:`snakeoil.osutils` instead
"""
__all__ = (
- "mmap_and_close", "readlines_iter", "native_readlines",
+ "mmap_and_close",
+ "readlines_iter",
+ "native_readlines",
"native_readfile",
)
@@ -31,6 +33,7 @@ def mmap_and_close(fd, *args, **kwargs):
class readlines_iter:
__slots__ = ("iterable", "mtime", "source")
+
def __init__(self, iterable, mtime, close=True, source=None):
if source is None:
source = iterable
@@ -54,17 +57,25 @@ class readlines_iter:
source.close()
def close(self):
- if hasattr(self.source, 'close'):
+ if hasattr(self.source, "close"):
self.source.close()
def __iter__(self):
return self.iterable
+
def _native_readlines_shim(*args, **kwds):
- return native_readlines('r', *args, **kwds)
+ return native_readlines("r", *args, **kwds)
+
-def native_readlines(mode, mypath, strip_whitespace=True, swallow_missing=False,
- none_on_missing=False, encoding=None):
+def native_readlines(
+ mode,
+ mypath,
+ strip_whitespace=True,
+ swallow_missing=False,
+ none_on_missing=False,
+ encoding=None,
+):
"""Read a file, yielding each line.
:param mypath: fs path for the file to read
@@ -102,8 +113,10 @@ def _py2k_ascii_strict_filter(source):
raise ValueError("character ordinal over 127")
yield line
+
def _native_readfile_shim(*args, **kwds):
- return native_readfile('r', *args, **kwds)
+ return native_readfile("r", *args, **kwds)
+
def native_readfile(mode, mypath, none_on_missing=False, encoding=None):
"""Read a file, returning the contents.
diff --git a/src/snakeoil/bash.py b/src/snakeoil/bash.py
index 5eaaca2c..3ca737cb 100644
--- a/src/snakeoil/bash.py
+++ b/src/snakeoil/bash.py
@@ -15,19 +15,25 @@ from .fileutils import readlines
from .log import logger
from .mappings import ProtectedDict
-demand_compile_regexp('line_cont_regexp', r'^(.*[^\\]|)\\$')
-demand_compile_regexp('inline_comment_regexp', r'^.*\s#.*$')
-demand_compile_regexp('var_find', r'\\?(\${\w+}|\$\w+)')
-demand_compile_regexp('backslash_find', r'\\.')
-demand_compile_regexp('ansi_escape_re', r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')
+demand_compile_regexp("line_cont_regexp", r"^(.*[^\\]|)\\$")
+demand_compile_regexp("inline_comment_regexp", r"^.*\s#.*$")
+demand_compile_regexp("var_find", r"\\?(\${\w+}|\$\w+)")
+demand_compile_regexp("backslash_find", r"\\.")
+demand_compile_regexp("ansi_escape_re", r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
__all__ = (
- "iter_read_bash", "read_bash", "read_dict", "read_bash_dict",
- "bash_parser", "BashParseError")
-
-
-def iter_read_bash(bash_source, allow_inline_comments=True,
- allow_line_cont=False, enum_line=False):
+ "iter_read_bash",
+ "read_bash",
+ "read_dict",
+ "read_bash_dict",
+ "bash_parser",
+ "BashParseError",
+)
+
+
+def iter_read_bash(
+ bash_source, allow_inline_comments=True, allow_line_cont=False, enum_line=False
+):
"""Iterate over a file honoring bash commenting rules and line continuations.
Note that it's considered good behaviour to close filehandles, as
@@ -45,7 +51,7 @@ def iter_read_bash(bash_source, allow_inline_comments=True,
"""
if isinstance(bash_source, str):
bash_source = readlines(bash_source, True)
- s = ''
+ s = ""
for lineno, line in enumerate(bash_source, 1):
if allow_line_cont and s:
s += line
@@ -53,19 +59,20 @@ def iter_read_bash(bash_source, allow_inline_comments=True,
s = line.lstrip()
if s:
- if s[0] != '#':
+ if s[0] != "#":
if allow_inline_comments:
- if (not allow_line_cont or
- (allow_line_cont and inline_comment_regexp.match(line))):
+ if not allow_line_cont or (
+ allow_line_cont and inline_comment_regexp.match(line)
+ ):
s = s.split("#", 1)[0].rstrip()
if allow_line_cont and line_cont_regexp.match(line):
- s = s.rstrip('\\\n')
+ s = s.rstrip("\\\n")
continue
if enum_line:
yield lineno, s.rstrip()
else:
yield s.rstrip()
- s = ''
+ s = ""
if s:
if enum_line:
yield lineno, s
@@ -122,7 +129,7 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None):
try:
while tok is not None:
key = s.get_token()
- if key == 'export':
+ if key == "export":
# discard 'export' token from "export VAR=VALUE" lines
key = s.get_token()
if key is None:
@@ -133,23 +140,23 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None):
# detect empty assigns
continue
eq = s.get_token()
- if eq != '=':
+ if eq != "=":
raise BashParseError(
- bash_source, s.lineno,
- "got token %r, was expecting '='" % eq)
+ bash_source, s.lineno, "got token %r, was expecting '='" % eq
+ )
val = s.get_token()
if val is None:
- val = ''
- elif val == 'export':
+ val = ""
+ elif val == "export":
val = s.get_token()
# look ahead to see if we just got an empty assign.
next_tok = s.get_token()
- if next_tok == '=':
+ if next_tok == "=":
# ... we did.
# leftmost insertions, thus reversed ordering
s.push_token(next_tok)
s.push_token(val)
- val = ''
+ val = ""
else:
s.push_token(next_tok)
d[key] = val
@@ -163,9 +170,15 @@ def read_bash_dict(bash_source, vars_dict=None, sourcing_command=None):
return d
-def read_dict(bash_source, splitter="=", source_isiter=False,
- allow_inline_comments=True, strip=False, filename=None,
- ignore_errors=False):
+def read_dict(
+ bash_source,
+ splitter="=",
+ source_isiter=False,
+ allow_inline_comments=True,
+ strip=False,
+ filename=None,
+ ignore_errors=False,
+):
"""Read key value pairs from a file, ignoring bash-style comments.
:param splitter: the string to split on. Can be None to
@@ -180,12 +193,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False,
d = {}
if not source_isiter:
filename = bash_source
- i = iter_read_bash(
- bash_source, allow_inline_comments=allow_inline_comments)
+ i = iter_read_bash(bash_source, allow_inline_comments=allow_inline_comments)
else:
if filename is None:
# XXX what to do?
- filename = '<unknown>'
+ filename = "<unknown>"
i = bash_source
line_count = 0
try:
@@ -195,10 +207,11 @@ def read_dict(bash_source, splitter="=", source_isiter=False,
k, v = k.split(splitter, 1)
except ValueError as e:
if filename == "<unknown>":
- filename = getattr(bash_source, 'name', bash_source)
+ filename = getattr(bash_source, "name", bash_source)
if ignore_errors:
logger.error(
- 'bash parse error in %r, line %s', filename, line_count)
+ "bash parse error in %r, line %s", filename, line_count
+ )
continue
else:
raise BashParseError(filename, line_count) from e
@@ -239,7 +252,7 @@ class bash_parser(shlex):
:param env: initial environment to use for variable interpolation
:type env: must be a mapping; if None, an empty dict is used
"""
- self.__dict__['state'] = ' '
+ self.__dict__["state"] = " "
super().__init__(source, posix=True, infile=infile)
self.wordchars += "@${}/.-+/:~^*"
self.wordchars = frozenset(self.wordchars)
@@ -252,12 +265,10 @@ class bash_parser(shlex):
def __setattr__(self, attr, val):
if attr == "state":
- if (self.state, val) in (
- ('"', 'a'), ('a', '"'), ('a', ' '), ("'", 'a')):
+ if (self.state, val) in (('"', "a"), ("a", '"'), ("a", " "), ("'", "a")):
strl = len(self.token)
if self.__pos != strl:
- self.changed_state.append(
- (self.state, self.token[self.__pos:]))
+ self.changed_state.append((self.state, self.token[self.__pos :]))
self.__pos = strl
self.__dict__[attr] = val
@@ -275,13 +286,13 @@ class bash_parser(shlex):
return token
if self.state is None:
# eof reached.
- self.changed_state.append((self.state, token[self.__pos:]))
+ self.changed_state.append((self.state, token[self.__pos :]))
else:
- self.changed_state.append((self.state, self.token[self.__pos:]))
- tok = ''
+ self.changed_state.append((self.state, self.token[self.__pos :]))
+ tok = ""
for s, t in self.changed_state:
if s in ('"', "a"):
- tok += self.var_expand(t).replace("\\\n", '')
+ tok += self.var_expand(t).replace("\\\n", "")
else:
tok += t
return tok
@@ -291,26 +302,27 @@ class bash_parser(shlex):
l = []
while match := var_find.search(val, pos):
pos = match.start()
- if val[pos] == '\\':
+ if val[pos] == "\\":
# it's escaped. either it's \\$ or \\${ , either way,
# skipping two ahead handles it.
pos += 2
else:
- var = val[match.start():match.end()].strip("${}")
+ var = val[match.start() : match.end()].strip("${}")
if prev != pos:
l.append(val[prev:pos])
if var in self.env:
if not isinstance(self.env[var], str):
raise ValueError(
- "env key %r must be a string, not %s: %r" % (
- var, type(self.env[var]), self.env[var]))
+ "env key %r must be a string, not %s: %r"
+ % (var, type(self.env[var]), self.env[var])
+ )
l.append(self.env[var])
else:
l.append("")
prev = pos = match.end()
# do \\ cleansing, collapsing val down also.
- val = backslash_find.sub(_nuke_backslash, ''.join(l) + val[prev:])
+ val = backslash_find.sub(_nuke_backslash, "".join(l) + val[prev:])
return val
@@ -320,10 +332,11 @@ class BashParseError(Exception):
def __init__(self, filename, line, errmsg=None):
if errmsg is not None:
super().__init__(
- "error parsing '%s' on or before line %i: err %s" %
- (filename, line, errmsg))
+ "error parsing '%s' on or before line %i: err %s"
+ % (filename, line, errmsg)
+ )
else:
super().__init__(
- "error parsing '%s' on or before line %i" %
- (filename, line))
+ "error parsing '%s' on or before line %i" % (filename, line)
+ )
self.file, self.line, self.errmsg = filename, line, errmsg
diff --git a/src/snakeoil/caching.py b/src/snakeoil/caching.py
index 3ed75d23..b20f0741 100644
--- a/src/snakeoil/caching.py
+++ b/src/snakeoil/caching.py
@@ -79,21 +79,22 @@ class WeakInstMeta(type):
Examples of usage is the restrictions subsystem for
U{pkgcore project<http://pkgcore.org>}
"""
+
def __new__(cls, name, bases, d):
if d.get("__inst_caching__", False):
d["__inst_caching__"] = True
d["__inst_dict__"] = WeakValueDictionary()
else:
d["__inst_caching__"] = False
- slots = d.get('__slots__')
+ slots = d.get("__slots__")
# get ourselves a singleton to be safe...
o = object()
if slots is not None:
for base in bases:
- if getattr(base, '__weakref__', o) is not o:
+ if getattr(base, "__weakref__", o) is not o:
break
else:
- d['__slots__'] = tuple(slots) + ('__weakref__',)
+ d["__slots__"] = tuple(slots) + ("__weakref__",)
return type.__new__(cls, name, bases, d)
def __call__(cls, *a, **kw):
@@ -105,8 +106,7 @@ class WeakInstMeta(type):
try:
instance = cls.__inst_dict__.get(key)
except (NotImplementedError, TypeError) as t:
- warnings.warn(
- f"caching keys for {cls}, got {t} for a={a}, kw={kw}")
+ warnings.warn(f"caching keys for {cls}, got {t} for a={a}, kw={kw}")
del t
key = instance = None
diff --git a/src/snakeoil/chksum/__init__.py b/src/snakeoil/chksum/__init__.py
index 8f2bc731..5bd6a7ce 100644
--- a/src/snakeoil/chksum/__init__.py
+++ b/src/snakeoil/chksum/__init__.py
@@ -64,10 +64,9 @@ def init(additional_handlers=None):
:param additional_handlers: None, or pass in a dict of type:func
"""
- global __inited__ # pylint: disable=global-statement
+ global __inited__ # pylint: disable=global-statement
- if additional_handlers is not None and not isinstance(
- additional_handlers, dict):
+ if additional_handlers is not None and not isinstance(additional_handlers, dict):
raise TypeError("additional handlers must be a dict!")
chksum_types.clear()
@@ -119,15 +118,19 @@ def get_chksums(location, *chksums, **kwds):
# try to hand off to the per file handler, may be faster.
if len(chksums) == 1:
return [handlers[chksums[0]](location)]
- if len(chksums) == 2 and 'size' in chksums:
+ if len(chksums) == 2 and "size" in chksums:
parallelize = False
else:
parallelize = kwds.get("parallelize", True)
can_mmap = True
for k in chksums:
can_mmap &= handlers[k].can_mmap
- return chksum_loop_over_file(location, [handlers[k].new() for k in chksums],
- parallelize=parallelize, can_mmap=can_mmap)
+ return chksum_loop_over_file(
+ location,
+ [handlers[k].new() for k in chksums],
+ parallelize=parallelize,
+ can_mmap=can_mmap,
+ )
class LazilyHashedPath(metaclass=klass.immutable_instance):
@@ -135,7 +138,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance):
def __init__(self, path, **initial_values):
f = object.__setattr__
- f(self, 'path', path)
+ f(self, "path", path)
for attr, val in initial_values.items():
f(self, attr, val)
@@ -143,7 +146,7 @@ class LazilyHashedPath(metaclass=klass.immutable_instance):
if not attr.islower():
# Disallow sHa1.
raise AttributeError(attr)
- elif attr == 'mtime':
+ elif attr == "mtime":
val = osutils.stat_mtime_long(self.path)
else:
try:
diff --git a/src/snakeoil/chksum/defaults.py b/src/snakeoil/chksum/defaults.py
index bf2be41a..fe01377b 100644
--- a/src/snakeoil/chksum/defaults.py
+++ b/src/snakeoil/chksum/defaults.py
@@ -11,7 +11,7 @@ from sys import intern
from ..data_source import base as base_data_source
from ..fileutils import mmap_or_open_for_read
-blocksize = 2 ** 17
+blocksize = 2**17
blake2b_size = 128
blake2s_size = 64
@@ -36,8 +36,11 @@ def chf_thread(queue, callback):
def chksum_loop_over_file(filename, chfs, parallelize=True, can_mmap=True):
chfs = [chf() for chf in chfs]
loop_over_file(
- filename, [chf.update for chf in chfs],
- parallelize=parallelize, can_mmap=can_mmap)
+ filename,
+ [chf.update for chf in chfs],
+ parallelize=parallelize,
+ can_mmap=can_mmap,
+ )
return [int(chf.hexdigest(), 16) for chf in chfs]
@@ -54,7 +57,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True):
else:
f = handle
close_f = False
- if getattr(handle, 'encoding', None):
+ if getattr(handle, "encoding", None):
# wanker. bypass the encoding, go straight to the raw source.
f = f.buffer
# reset; we do it for compat, but it also avoids unpleasant issues from
@@ -68,8 +71,10 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True):
if parallelize:
queues = [queue.Queue(8) for _ in callbacks]
- threads = [threading.Thread(target=chf_thread, args=(queue, functor))
- for queue, functor in zip(queues, callbacks)]
+ threads = [
+ threading.Thread(target=chf_thread, args=(queue, functor))
+ for queue, functor in zip(queues, callbacks)
+ ]
for thread in threads:
thread.start()
@@ -79,7 +84,7 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True):
if m is not None:
for callback in callbacks:
callback(m)
- elif hasattr(f, 'getvalue'):
+ elif hasattr(f, "getvalue"):
data = f.getvalue()
if not isinstance(data, bytes):
data = data.encode()
@@ -107,7 +112,6 @@ def loop_over_file(handle, callbacks, parallelize=True, can_mmap=True):
class Chksummer:
-
def __init__(self, chf_type, obj, str_size, can_mmap=True):
self.obj = obj
self.chf_type = chf_type
@@ -118,15 +122,14 @@ class Chksummer:
return self.obj
def long2str(self, val):
- return ("%x" % val).rjust(self.str_size, '0')
+ return ("%x" % val).rjust(self.str_size, "0")
@staticmethod
def str2long(val):
return int(val, 16)
def __call__(self, filename):
- return chksum_loop_over_file(
- filename, [self.obj], can_mmap=self.can_mmap)[0]
+ return chksum_loop_over_file(filename, [self.obj], can_mmap=self.can_mmap)[0]
def __str__(self):
return "%s chksummer" % self.chf_type
@@ -134,31 +137,26 @@ class Chksummer:
chksum_types = {
chksumname: Chksummer(chksumname, partial(hashlib.new, hashlibname), size)
-
for hashlibname, chksumname, size in [
# conditional upon FIPS, but available in >3.8.
- ('md5', 'md5', md5_size),
-
+ ("md5", "md5", md5_size),
# Guaranteed as of python 3.8
- ('blake2b', 'blake2b', blake2b_size),
- ('blake2s', 'blake2s', blake2s_size),
- ('sha1', 'sha1', sha1_size),
- ('sha256', 'sha256', sha256_size),
- ('sha3_256', 'sha3_256', sha3_256_size),
- ('sha3_512', 'sha3_512', sha3_512_size),
- ('sha512', 'sha512', sha512_size),
-
+ ("blake2b", "blake2b", blake2b_size),
+ ("blake2s", "blake2s", blake2s_size),
+ ("sha1", "sha1", sha1_size),
+ ("sha256", "sha256", sha256_size),
+ ("sha3_256", "sha3_256", sha3_256_size),
+ ("sha3_512", "sha3_512", sha3_512_size),
+ ("sha512", "sha512", sha512_size),
# not guaranteed, but may be available.
- ('whirlpool', 'whirlpool', whirlpool_size),
- ('ripemd160', 'rmd160', rmd160_size),
-
+ ("whirlpool", "whirlpool", whirlpool_size),
+ ("ripemd160", "rmd160", rmd160_size),
]
if hashlibname in hashlib.algorithms_available
}
class SizeUpdater:
-
def __init__(self):
self.count = 0
@@ -176,8 +174,7 @@ class SizeChksummer(Chksummer):
"""
def __init__(self):
- super().__init__(
- chf_type='size', obj=SizeUpdater, str_size=1000000000)
+ super().__init__(chf_type="size", obj=SizeUpdater, str_size=1000000000)
@staticmethod
def long2str(val):
diff --git a/src/snakeoil/cli/arghparse.py b/src/snakeoil/cli/arghparse.py
index 1c832c3e..774699ed 100644
--- a/src/snakeoil/cli/arghparse.py
+++ b/src/snakeoil/cli/arghparse.py
@@ -9,8 +9,18 @@ import pkgutil
import subprocess
import sys
import traceback
-from argparse import (_UNRECOGNIZED_ARGS_ATTR, OPTIONAL, PARSER, REMAINDER, SUPPRESS, ZERO_OR_MORE,
- ArgumentError, _, _get_action_name, _SubParsersAction)
+from argparse import (
+ _UNRECOGNIZED_ARGS_ATTR,
+ OPTIONAL,
+ PARSER,
+ REMAINDER,
+ SUPPRESS,
+ ZERO_OR_MORE,
+ ArgumentError,
+ _,
+ _get_action_name,
+ _SubParsersAction,
+)
from collections import Counter
from functools import partial
from itertools import chain
@@ -31,11 +41,11 @@ from ..version import get_version
_generate_docs = False
-@klass.patch('argparse.ArgumentParser.add_subparsers')
-@klass.patch('argparse._SubParsersAction.add_parser')
-@klass.patch('argparse._ActionsContainer.add_mutually_exclusive_group')
-@klass.patch('argparse._ActionsContainer.add_argument_group')
-@klass.patch('argparse._ActionsContainer.add_argument')
+@klass.patch("argparse.ArgumentParser.add_subparsers")
+@klass.patch("argparse._SubParsersAction.add_parser")
+@klass.patch("argparse._ActionsContainer.add_mutually_exclusive_group")
+@klass.patch("argparse._ActionsContainer.add_argument_group")
+@klass.patch("argparse._ActionsContainer.add_argument")
def _add_argument_docs(orig_func, self, *args, **kwargs):
"""Enable docs keyword argument support for argparse arguments.
@@ -48,16 +58,16 @@ def _add_argument_docs(orig_func, self, *args, **kwargs):
enable the global _generate_docs variable in order to replace the
summarized help strings with the extended doc strings.
"""
- docs = kwargs.pop('docs', None)
+ docs = kwargs.pop("docs", None)
obj = orig_func(self, *args, **kwargs)
if _generate_docs and docs is not None:
if isinstance(docs, (list, tuple)):
# list args are often used if originator wanted to strip
# off first description summary line
- docs = '\n'.join(docs)
- docs = '\n'.join(dedent(docs).strip().split('\n'))
+ docs = "\n".join(docs)
+ docs = "\n".join(dedent(docs).strip().split("\n"))
- if orig_func.__name__ == 'add_subparsers':
+ if orig_func.__name__ == "add_subparsers":
# store original description before overriding it with extended
# docs for general subparsers argument groups
self._subparsers._description = self._subparsers.description
@@ -93,7 +103,7 @@ class ParseNonblockingStdin(argparse.Action):
"""Accept arguments from standard input in a non-blocking fashion."""
def __init__(self, *args, **kwargs):
- self.filter_func = kwargs.pop('filter_func', lambda x: x.strip())
+ self.filter_func = kwargs.pop("filter_func", lambda x: x.strip())
super().__init__(*args, **kwargs)
def _stdin(self):
@@ -106,9 +116,11 @@ class ParseNonblockingStdin(argparse.Action):
break
def __call__(self, parser, namespace, values, option_string=None):
- if values is not None and len(values) == 1 and values[0] == '-':
+ if values is not None and len(values) == 1 and values[0] == "-":
if sys.stdin.isatty():
- raise argparse.ArgumentError(self, "'-' is only valid when piping data in")
+ raise argparse.ArgumentError(
+ self, "'-' is only valid when piping data in"
+ )
values = self._stdin()
setattr(namespace, self.dest, values)
@@ -117,16 +129,18 @@ class ParseStdin(ExtendAction):
"""Accept arguments from standard input in a blocking fashion."""
def __init__(self, *args, **kwargs):
- self.filter_func = kwargs.pop('filter_func', lambda x: x.strip())
+ self.filter_func = kwargs.pop("filter_func", lambda x: x.strip())
super().__init__(*args, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
- if values is not None and len(values) == 1 and values[0] == '-':
+ if values is not None and len(values) == 1 and values[0] == "-":
if sys.stdin.isatty():
- raise argparse.ArgumentError(self, "'-' is only valid when piping data in")
+ raise argparse.ArgumentError(
+ self, "'-' is only valid when piping data in"
+ )
values = [x.rstrip() for x in sys.stdin.readlines() if self.filter_func(x)]
# reassign stdin to allow interactivity (currently only works for unix)
- sys.stdin = open('/dev/tty')
+ sys.stdin = open("/dev/tty")
super().__call__(parser, namespace, values, option_string)
@@ -136,10 +150,10 @@ class CommaSeparatedValues(argparse._AppendAction):
def parse_values(self, values):
items = []
if isinstance(values, str):
- items.extend(x for x in values.split(',') if x)
+ items.extend(x for x in values.split(",") if x)
else:
for value in values:
- items.extend(x for x in value.split(',') if x)
+ items.extend(x for x in value.split(",") if x)
return items
def __call__(self, parser, namespace, values, option_string=None):
@@ -174,16 +188,16 @@ class CommaSeparatedNegations(argparse._AppendAction):
values = [values]
for value in values:
try:
- neg, pos = split_negations(x for x in value.split(',') if x)
+ neg, pos = split_negations(x for x in value.split(",") if x)
except ValueError as e:
raise argparse.ArgumentTypeError(e)
disabled.extend(neg)
enabled.extend(pos)
if colliding := set(disabled).intersection(enabled):
- collisions = ', '.join(map(repr, sorted(colliding)))
+ collisions = ", ".join(map(repr, sorted(colliding)))
s = pluralism(colliding)
- msg = f'colliding value{s}: {collisions}'
+ msg = f"colliding value{s}: {collisions}"
raise argparse.ArgumentError(self, msg)
return disabled, enabled
@@ -222,7 +236,7 @@ class CommaSeparatedElements(argparse._AppendAction):
values = [values]
for value in values:
try:
- neg, neu, pos = split_elements(x for x in value.split(',') if x)
+ neg, neu, pos = split_elements(x for x in value.split(",") if x)
except ValueError as e:
raise argparse.ArgumentTypeError(e)
disabled.extend(neg)
@@ -231,9 +245,9 @@ class CommaSeparatedElements(argparse._AppendAction):
elements = [set(x) for x in (disabled, neutral, enabled) if x]
if len(elements) > 1 and (colliding := set.intersection(*elements)):
- collisions = ', '.join(map(repr, sorted(colliding)))
+ collisions = ", ".join(map(repr, sorted(colliding)))
s = pluralism(colliding)
- msg = f'colliding value{s}: {collisions}'
+ msg = f"colliding value{s}: {collisions}"
raise argparse.ArgumentError(self, msg)
return disabled, neutral, enabled
@@ -260,14 +274,14 @@ class ManHelpAction(argparse._HelpAction):
"""Display man pages for long --help option and abbreviated output for -h."""
def __call__(self, parser, namespace, values, option_string=None):
- if option_string == '--help':
+ if option_string == "--help":
# Try spawning man page -- assumes one level deep for subcommand
# specific man pages with commands separated by hyphen. For example
# running `pinspect profile --help` tries to open pinspect-profile
# man page, but `pinspect profile masks --help` also tries to open
# pinspect-profile.
- man_page = '-'.join(parser.prog.split()[:2])
- p = subprocess.Popen(['man', man_page], stderr=subprocess.DEVNULL)
+ man_page = "-".join(parser.prog.split()[:2])
+ p = subprocess.Popen(["man", man_page], stderr=subprocess.DEVNULL)
p.communicate()
if p.returncode == 0:
parser.exit()
@@ -279,16 +293,17 @@ class ManHelpAction(argparse._HelpAction):
class StoreBool(argparse._StoreAction):
-
- def __init__(self,
- option_strings,
- dest,
- nargs=None,
- const=None,
- default=None,
- required=False,
- help=None,
- metavar='BOOLEAN'):
+ def __init__(
+ self,
+ option_strings,
+ dest,
+ nargs=None,
+ const=None,
+ default=None,
+ required=False,
+ help=None,
+ metavar="BOOLEAN",
+ ):
super().__init__(
option_strings=option_strings,
dest=dest,
@@ -298,38 +313,42 @@ class StoreBool(argparse._StoreAction):
type=self.boolean,
required=required,
help=help,
- metavar=metavar)
+ metavar=metavar,
+ )
@staticmethod
def boolean(value):
value = value.lower()
- if value in ('y', 'yes', 'true', '1'):
+ if value in ("y", "yes", "true", "1"):
return True
- elif value in ('n', 'no', 'false', '0'):
+ elif value in ("n", "no", "false", "0"):
return False
raise ValueError("value %r must be [y|yes|true|1|n|no|false|0]" % (value,))
class EnableDebug(argparse._StoreTrueAction):
-
def __call__(self, parser, namespace, values, option_string=None):
super().__call__(parser, namespace, values, option_string=option_string)
logging.root.setLevel(logging.DEBUG)
class Verbosity(argparse.Action):
-
def __init__(self, option_strings, dest, default=None, required=False, help=None):
super().__init__(
- option_strings=option_strings, dest=dest, nargs=0,
- default=default, required=required, help=help)
+ option_strings=option_strings,
+ dest=dest,
+ nargs=0,
+ default=default,
+ required=required,
+ help=help,
+ )
# map verbose/quiet args to increment/decrement the underlying verbosity value
self.value_map = {
- '-q': -1,
- '--quiet': -1,
- '-v': 1,
- '--verbose': 1,
+ "-q": -1,
+ "--quiet": -1,
+ "-v": 1,
+ "--verbose": 1,
}
def __call__(self, parser, namespace, values, option_string=None):
@@ -343,7 +362,6 @@ class Verbosity(argparse.Action):
class DelayedValue:
-
def __init__(self, invokable, priority=0):
self.priority = priority
if not callable(invokable):
@@ -355,7 +373,6 @@ class DelayedValue:
class DelayedDefault(DelayedValue):
-
@classmethod
def wipe(cls, attrs, priority):
if isinstance(attrs, str):
@@ -376,20 +393,17 @@ class DelayedDefault(DelayedValue):
class DelayedParse(DelayedValue):
-
def __call__(self, namespace, attr):
self.invokable()
class OrderedParse(DelayedValue):
-
def __call__(self, namespace, attr):
self.invokable(namespace)
delattr(namespace, attr)
class Delayed(argparse.Action):
-
def __init__(self, option_strings, dest, target=None, priority=0, **kwargs):
if target is None:
raise ValueError("target must be non None for Delayed")
@@ -397,21 +411,30 @@ class Delayed(argparse.Action):
self.priority = int(priority)
self.target = target(option_strings=option_strings, dest=dest, **kwargs.copy())
super().__init__(
- option_strings=option_strings[:], dest=dest,
- nargs=kwargs.get("nargs", None), required=kwargs.get("required", None),
- help=kwargs.get("help", None), metavar=kwargs.get("metavar", None),
- default=kwargs.get("default", None))
+ option_strings=option_strings[:],
+ dest=dest,
+ nargs=kwargs.get("nargs", None),
+ required=kwargs.get("required", None),
+ help=kwargs.get("help", None),
+ metavar=kwargs.get("metavar", None),
+ default=kwargs.get("default", None),
+ )
def __call__(self, parser, namespace, values, option_string=None):
- setattr(namespace, self.dest, DelayedParse(
- partial(self.target, parser, namespace, values, option_string),
- self.priority))
+ setattr(
+ namespace,
+ self.dest,
+ DelayedParse(
+ partial(self.target, parser, namespace, values, option_string),
+ self.priority,
+ ),
+ )
class Expansion(argparse.Action):
-
- def __init__(self, option_strings, dest, nargs=None, help=None,
- required=None, subst=None):
+ def __init__(
+ self, option_strings, dest, nargs=None, help=None, required=None, subst=None
+ ):
if subst is None:
raise TypeError("substitution string must be set")
# simple aliases with no required arguments shouldn't need to specify nargs
@@ -424,7 +447,8 @@ class Expansion(argparse.Action):
help=help,
required=required,
default=False,
- nargs=nargs)
+ nargs=nargs,
+ )
self.subst = tuple(subst)
def __call__(self, parser, namespace, values, option_string=None):
@@ -434,7 +458,7 @@ class Expansion(argparse.Action):
if isinstance(values, str):
vals = [vals]
dvals = {str(idx): val for idx, val in enumerate(vals)}
- dvals['*'] = ' '.join(vals)
+ dvals["*"] = " ".join(vals)
for action in actions:
action_map.update((option, action) for option in action.option_strings)
@@ -445,8 +469,8 @@ class Expansion(argparse.Action):
args = [x % dvals for x in args]
if not action:
raise ValueError(
- "unable to find option %r for %r" %
- (option, self.option_strings))
+ "unable to find option %r for %r" % (option, self.option_strings)
+ )
if action.type is not None:
args = list(map(action.type, args))
if action.nargs in (1, None):
@@ -456,7 +480,6 @@ class Expansion(argparse.Action):
class _SubParser(argparse._SubParsersAction):
-
def add_parser(self, name, cls=None, **kwargs):
"""Subparser that links description/help if one is specified."""
description = kwargs.get("description")
@@ -465,7 +488,7 @@ class _SubParser(argparse._SubParsersAction):
if help_txt is not None:
kwargs["description"] = help_txt
elif help_txt is None:
- kwargs["help"] = description.split('\n', 1)[0]
+ kwargs["help"] = description.split("\n", 1)[0]
# support using a custom parser class for the subparser
orig_class = self._parser_class
@@ -486,7 +509,7 @@ class _SubParser(argparse._SubParsersAction):
Note that this assumes a specific module naming and layout scheme for commands.
"""
prog = self._prog_prefix
- module = f'{prog}.scripts.{prog}_{subcmd}'
+ module = f"{prog}.scripts.{prog}_{subcmd}"
func = partial(self._lazy_parser, module, subcmd)
self._name_parser_map[subcmd] = lazy_object_proxy.Proxy(func)
@@ -507,8 +530,8 @@ class _SubParser(argparse._SubParsersAction):
try:
parser = self._name_parser_map[parser_name]
except KeyError:
- tup = parser_name, ', '.join(self._name_parser_map)
- msg = _('unknown parser %r (choices: %s)') % tup
+ tup = parser_name, ", ".join(self._name_parser_map)
+ msg = _("unknown parser %r (choices: %s)") % tup
raise argparse.ArgumentError(self, msg)
# parse all the remaining options into the namespace
@@ -526,11 +549,13 @@ class CsvHelpFormatter(argparse.HelpFormatter):
def _format_args(self, action, default_metavar):
get_metavar = self._metavar_formatter(action, default_metavar)
if isinstance(action, (CommaSeparatedValues, CommaSeparatedValuesAppend)):
- result = '%s[,%s,...]' % get_metavar(2)
- elif isinstance(action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend)):
- result = '%s[,-%s,...]' % get_metavar(2)
+ result = "%s[,%s,...]" % get_metavar(2)
+ elif isinstance(
+ action, (CommaSeparatedNegations, CommaSeparatedNegationsAppend)
+ ):
+ result = "%s[,-%s,...]" % get_metavar(2)
elif isinstance(action, (CommaSeparatedElements, CommaSeparatedElementsAppend)):
- result = '%s[,-%s,+%s...]' % get_metavar(3)
+ result = "%s[,-%s,+%s...]" % get_metavar(3)
else:
result = super()._format_args(action, default_metavar)
return result
@@ -540,7 +565,7 @@ class SortedHelpFormatter(CsvHelpFormatter):
"""Help formatter that sorts arguments by option strings."""
def add_arguments(self, actions):
- actions = sorted(actions, key=attrgetter('option_strings'))
+ actions = sorted(actions, key=attrgetter("option_strings"))
super().add_arguments(actions)
@@ -576,7 +601,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser):
# for everything but PARSER, REMAINDER args, strip out first '--'
if action.nargs not in [PARSER, REMAINDER]:
try:
- arg_strings.remove('--')
+ arg_strings.remove("--")
except ValueError:
pass
@@ -592,8 +617,11 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser):
# when nargs='*' on a positional, if there were no command-line
# args, use the default if it is anything other than None
- elif (not arg_strings and action.nargs == ZERO_OR_MORE and
- not action.option_strings):
+ elif (
+ not arg_strings
+ and action.nargs == ZERO_OR_MORE
+ and not action.option_strings
+ ):
if action.default is not None:
value = action.default
else:
@@ -602,7 +630,7 @@ class SubcmdAbbrevArgumentParser(argparse.ArgumentParser):
# single argument or optional argument produces a single value
elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]:
- arg_string, = arg_strings
+ (arg_string,) = arg_strings
value = self._get_value(action, arg_string)
self._check_value(action, value)
@@ -688,7 +716,7 @@ class OptionalsParser(argparse.ArgumentParser):
for i, mutex_action in enumerate(mutex_group._group_actions):
conflicts = action_conflicts.setdefault(mutex_action, [])
conflicts.extend(group_actions[:i])
- conflicts.extend(group_actions[i + 1:])
+ conflicts.extend(group_actions[i + 1 :])
# find all option indices, and determine the arg_string_pattern
# which has an 'O' if there is an option at an index,
@@ -699,24 +727,24 @@ class OptionalsParser(argparse.ArgumentParser):
for i, arg_string in enumerate(arg_strings_iter):
# all args after -- are non-options
- if arg_string == '--':
- arg_string_pattern_parts.append('-')
+ if arg_string == "--":
+ arg_string_pattern_parts.append("-")
for arg_string in arg_strings_iter:
- arg_string_pattern_parts.append('A')
+ arg_string_pattern_parts.append("A")
# otherwise, add the arg to the arg strings
# and note the index if it was an option
else:
option_tuple = self._parse_optional(arg_string)
if option_tuple is None:
- pattern = 'A'
+ pattern = "A"
else:
option_string_indices[i] = option_tuple
- pattern = 'O'
+ pattern = "O"
arg_string_pattern_parts.append(pattern)
# join the pieces together to form the pattern
- arg_strings_pattern = ''.join(arg_string_pattern_parts)
+ arg_strings_pattern = "".join(arg_string_pattern_parts)
# converts arg strings to the appropriate and then takes the action
seen_actions = set()
@@ -733,7 +761,7 @@ class OptionalsParser(argparse.ArgumentParser):
seen_non_default_actions.add(action)
for conflict_action in action_conflicts.get(action, []):
if conflict_action in seen_non_default_actions:
- msg = _('not allowed with argument %s')
+ msg = _("not allowed with argument %s")
action_name = _get_action_name(conflict_action)
raise ArgumentError(action, msg % action_name)
@@ -762,14 +790,14 @@ class OptionalsParser(argparse.ArgumentParser):
# if we match help options, skip them for now so subparsers
# show up in the help output
- if arg_strings[start_index] in ('-h', '--help'):
+ if arg_strings[start_index] in ("-h", "--help"):
extras.append(arg_strings[start_index])
return start_index + 1
# if there is an explicit argument, try to match the
# optional's string arguments to only this
if explicit_arg is not None:
- arg_count = match_argument(action, 'A')
+ arg_count = match_argument(action, "A")
# if the action is a single-dash option and takes no
# arguments, try to parse more single-dash options out
@@ -785,7 +813,7 @@ class OptionalsParser(argparse.ArgumentParser):
action = optionals_map[option_string]
explicit_arg = new_explicit_arg
else:
- msg = _('ignored explicit argument %r')
+ msg = _("ignored explicit argument %r")
raise ArgumentError(action, msg % explicit_arg)
# if the action expect exactly one argument, we've
@@ -799,7 +827,7 @@ class OptionalsParser(argparse.ArgumentParser):
# error if a double-dash option did not use the
# explicit argument
else:
- msg = _('ignored explicit argument %r')
+ msg = _("ignored explicit argument %r")
raise ArgumentError(action, msg % explicit_arg)
# if there is no explicit argument, try to match the
@@ -835,13 +863,13 @@ class OptionalsParser(argparse.ArgumentParser):
# slice off the appropriate arg strings for each Positional
# and add the Positional and its args to the list
for action, arg_count in zip(positionals, arg_counts):
- args = arg_strings[start_index: start_index + arg_count]
+ args = arg_strings[start_index : start_index + arg_count]
start_index += arg_count
take_action(action, args)
# slice off the Positionals that we just parsed and return the
# index at which the Positionals' string args stopped
- positionals[:] = positionals[len(arg_counts):]
+ positionals[:] = positionals[len(arg_counts) :]
return start_index
# consume Positionals and Optionals alternately, until we have
@@ -855,10 +883,9 @@ class OptionalsParser(argparse.ArgumentParser):
while start_index <= max_option_string_index:
# consume any Positionals preceding the next option
- next_option_string_index = min([
- index
- for index in option_string_indices
- if index >= start_index])
+ next_option_string_index = min(
+ [index for index in option_string_indices if index >= start_index]
+ )
if start_index != next_option_string_index:
# positionals_end_index = consume_positionals(start_index)
positionals_end_index = start_index
@@ -894,7 +921,9 @@ class OptionalsParser(argparse.ArgumentParser):
for action in self._actions:
if action not in seen_actions:
# ignore required subcommands and positionals as they'll be handled later
- skip = not action.option_strings or isinstance(action, _SubParsersAction)
+ skip = not action.option_strings or isinstance(
+ action, _SubParsersAction
+ )
if action.required and not skip:
required_actions.append(_get_action_name(action))
else:
@@ -902,16 +931,23 @@ class OptionalsParser(argparse.ArgumentParser):
# parsing arguments to avoid calling convert functions
# twice (which may fail) if the argument was given, but
# only if it was defined already in the namespace
- if (action.default is not None and
- isinstance(action.default, str) and
- hasattr(namespace, action.dest) and
- action.default is getattr(namespace, action.dest)):
- setattr(namespace, action.dest,
- self._get_value(action, action.default))
+ if (
+ action.default is not None
+ and isinstance(action.default, str)
+ and hasattr(namespace, action.dest)
+ and action.default is getattr(namespace, action.dest)
+ ):
+ setattr(
+ namespace,
+ action.dest,
+ self._get_value(action, action.default),
+ )
if required_actions:
- self.error(_('the following arguments are required: %s') %
- ', '.join(required_actions))
+ self.error(
+ _("the following arguments are required: %s")
+ % ", ".join(required_actions)
+ )
# make sure all required groups had one option present
for group in self._mutually_exclusive_groups:
@@ -922,11 +958,13 @@ class OptionalsParser(argparse.ArgumentParser):
# if no actions were used, report the error
else:
- names = [_get_action_name(action)
- for action in group._group_actions
- if action.help is not SUPPRESS]
- msg = _('one of the arguments %s is required')
- self.error(msg % ' '.join(names))
+ names = [
+ _get_action_name(action)
+ for action in group._group_actions
+ if action.help is not SUPPRESS
+ ]
+ msg = _("one of the arguments %s is required")
+ self.error(msg % " ".join(names))
# return the updated namespace and the extra arguments
return namespace, extras
@@ -937,31 +975,49 @@ class CsvActionsParser(argparse.ArgumentParser):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.register('action', 'csv', CommaSeparatedValues)
- self.register('action', 'csv_append', CommaSeparatedValuesAppend)
- self.register('action', 'csv_negations', CommaSeparatedNegations)
- self.register('action', 'csv_negations_append', CommaSeparatedNegationsAppend)
- self.register('action', 'csv_elements', CommaSeparatedElements)
- self.register('action', 'csv_elements_append', CommaSeparatedElementsAppend)
+ self.register("action", "csv", CommaSeparatedValues)
+ self.register("action", "csv_append", CommaSeparatedValuesAppend)
+ self.register("action", "csv_negations", CommaSeparatedNegations)
+ self.register("action", "csv_negations_append", CommaSeparatedNegationsAppend)
+ self.register("action", "csv_elements", CommaSeparatedElements)
+ self.register("action", "csv_elements_append", CommaSeparatedElementsAppend)
class ArgumentParser(OptionalsParser, CsvActionsParser):
"""Extended, argparse-compatible argument parser."""
- def __init__(self, suppress=False, subcmds=False, color=True, debug=True, quiet=True,
- verbose=True, version=True, add_help=True, sorted_help=False,
- description=None, docs=None, script=None, prog=None, **kwargs):
- self.debug = debug and '--debug' in sys.argv[1:]
+ def __init__(
+ self,
+ suppress=False,
+ subcmds=False,
+ color=True,
+ debug=True,
+ quiet=True,
+ verbose=True,
+ version=True,
+ add_help=True,
+ sorted_help=False,
+ description=None,
+ docs=None,
+ script=None,
+ prog=None,
+ **kwargs,
+ ):
+ self.debug = debug and "--debug" in sys.argv[1:]
self.verbosity = int(verbose)
if self.verbosity:
argv = Counter(sys.argv[1:])
# Only supports single, short opts (i.e. -vv isn't recognized),
# post argparsing the proper value supporting those kind of args is
# in the options namespace.
- self.verbosity = sum(chain.from_iterable((
- (-1 for x in range(argv['-q'] + argv['--quiet'])),
- (1 for x in range(argv['-v'] + argv['--verbose'])),
- )))
+ self.verbosity = sum(
+ chain.from_iterable(
+ (
+ (-1 for x in range(argv["-q"] + argv["--quiet"])),
+ (1 for x in range(argv["-v"] + argv["--verbose"])),
+ )
+ )
+ )
# subparsers action object from calling add_subparsers()
self.__subparsers = None
@@ -979,7 +1035,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
# usage such as adding conflicting options to both the root command and
# subcommands without causing issues in addition to helping support
# default subparsers.
- self._parents = tuple(kwargs.get('parents', ()))
+ self._parents = tuple(kwargs.get("parents", ()))
# extract the description to use and set docs for doc generation
description = self._update_desc(description, docs)
@@ -993,11 +1049,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
raise TypeError
except TypeError:
raise ValueError(
- "invalid script parameter, should be (__file__, __name__)")
+ "invalid script parameter, should be (__file__, __name__)"
+ )
- project = script_module.split('.')[0]
+ project = script_module.split(".")[0]
if prog is None:
- prog = script_module.split('.')[-1]
+ prog = script_module.split(".")[-1]
if sorted_help:
formatter = SortedHelpFormatter
@@ -1005,27 +1062,36 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
formatter = CsvHelpFormatter
super().__init__(
- description=description, formatter_class=formatter,
- prog=prog, add_help=False, **kwargs)
+ description=description,
+ formatter_class=formatter,
+ prog=prog,
+ add_help=False,
+ **kwargs,
+ )
# register custom actions
- self.register('action', 'parsers', _SubParser)
+ self.register("action", "parsers", _SubParser)
if not suppress:
- base_opts = self.add_argument_group('base options')
+ base_opts = self.add_argument_group("base options")
if add_help:
base_opts.add_argument(
- '-h', '--help', action=ManHelpAction, default=argparse.SUPPRESS,
- help='show this help message and exit',
+ "-h",
+ "--help",
+ action=ManHelpAction,
+ default=argparse.SUPPRESS,
+ help="show this help message and exit",
docs="""
Show this help message and exit. To get more
information see the related man page.
- """)
+ """,
+ )
if version and script is not None:
# Note that this option will currently only be available on the
# base command, not on subcommands.
base_opts.add_argument(
- '--version', action='version',
+ "--version",
+ action="version",
version=get_version(project, script_path),
help="show this program's version info and exit",
docs="""
@@ -1034,39 +1100,58 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
When running from within a git repo or a version
installed from git the latest commit hash and date will
be shown.
- """)
+ """,
+ )
if debug:
base_opts.add_argument(
- '--debug', action=EnableDebug, help='enable debugging checks',
- docs='Enable debug checks and show verbose debug output.')
+ "--debug",
+ action=EnableDebug,
+ help="enable debugging checks",
+ docs="Enable debug checks and show verbose debug output.",
+ )
if quiet:
base_opts.add_argument(
- '-q', '--quiet', action=Verbosity, dest='verbosity', default=0,
- help='suppress non-error messages',
- docs="Suppress non-error, informational messages.")
+ "-q",
+ "--quiet",
+ action=Verbosity,
+ dest="verbosity",
+ default=0,
+ help="suppress non-error messages",
+ docs="Suppress non-error, informational messages.",
+ )
if verbose:
base_opts.add_argument(
- '-v', '--verbose', action=Verbosity, dest='verbosity', default=0,
- help='show verbose output',
- docs="Increase the verbosity of various output.")
+ "-v",
+ "--verbose",
+ action=Verbosity,
+ dest="verbosity",
+ default=0,
+ help="show verbose output",
+ docs="Increase the verbosity of various output.",
+ )
if color:
base_opts.add_argument(
- '--color', action=StoreBool,
+ "--color",
+ action=StoreBool,
default=sys.stdout.isatty(),
- help='enable/disable color support',
+ help="enable/disable color support",
docs="""
Toggle colored output support. This can be used to forcibly
enable color support when piping output or other sitations
where stdout is not a tty.
- """)
+ """,
+ )
# register existing subcommands
if subcmds:
- prefix = f'{prog}.scripts.{prog}_'
+ prefix = f"{prog}.scripts.{prog}_"
if subcmd_modules := [
- name[len(prefix):] for _, name, _ in
- pkgutil.walk_packages([os.path.dirname(script_path)], f'{prog}.scripts.')
- if name.startswith(prefix)]:
+ name[len(prefix) :]
+ for _, name, _ in pkgutil.walk_packages(
+ [os.path.dirname(script_path)], f"{prog}.scripts."
+ )
+ if name.startswith(prefix)
+ ]:
subparsers = self.add_subparsers()
for subcmd in subcmd_modules:
subparsers.add_command(subcmd)
@@ -1080,7 +1165,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
"""
description_lines = []
if description is not None:
- description_lines = description.strip().split('\n', 1)
+ description_lines = description.strip().split("\n", 1)
description = description_lines[0]
if _generate_docs:
if docs is None and len(description_lines) == 2:
@@ -1156,7 +1241,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
try:
# run registered early parse functions from all parsers
- for functor, parser in chain.from_iterable(x.__early_parse for x in self.parsers):
+ for functor, parser in chain.from_iterable(
+ x.__early_parse for x in self.parsers
+ ):
namespace, args = functor(parser, namespace, args)
# parse the arguments and exit if there are any errors
@@ -1176,7 +1263,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
args, unknown_args = self.parse_known_args(args, namespace)
# make sure the correct function and prog are set if running a subcommand
- subcmd_parser = self.subparsers.get(getattr(args, 'subcommand', None), None)
+ subcmd_parser = self.subparsers.get(getattr(args, "subcommand", None), None)
if subcmd_parser is not None:
# override the running program with full subcommand
self.prog = subcmd_parser.prog
@@ -1186,7 +1273,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
namespace.main_func = subcmd_parser.__main_func
if unknown_args:
- self.error('unrecognized arguments: %s' % ' '.join(unknown_args))
+ self.error("unrecognized arguments: %s" % " ".join(unknown_args))
# Two runs are required; first, handle any suppression defaults
# introduced. Subparsers defaults cannot override the parent parser, as
@@ -1198,14 +1285,20 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
# intentionally no protection of suppression code; this should
# just work.
- i = ((attr, val) for attr, val in args.__dict__.items()
- if isinstance(val, DelayedDefault))
+ i = (
+ (attr, val)
+ for attr, val in args.__dict__.items()
+ if isinstance(val, DelayedDefault)
+ )
for attr, functor in sorted(i, key=lambda val: val[1].priority):
functor(args, attr)
# now run the delays
- i = ((attr, val) for attr, val in args.__dict__.items()
- if isinstance(val, DelayedValue))
+ i = (
+ (attr, val)
+ for attr, val in args.__dict__.items()
+ if isinstance(val, DelayedValue)
+ )
try:
for attr, delayed in sorted(i, key=lambda val: val[1].priority):
delayed(args, attr)
@@ -1216,7 +1309,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
self.error(str(e))
# run final arg validation
- final_checks = [k for k in args.__dict__.keys() if k.startswith('__final_check__')]
+ final_checks = [
+ k for k in args.__dict__.keys() if k.startswith("__final_check__")
+ ]
for check in final_checks:
functor = args.pop(check)
functor(self, args)
@@ -1232,7 +1327,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
if self.debug and sys.exc_info() != (None, None, None):
# output traceback if any exception is on the stack
traceback.print_exc()
- self.exit(status, '%s: error: %s\n' % (self.prog, message))
+ self.exit(status, "%s: error: %s\n" % (self.prog, message))
def bind_main_func(self, functor):
"""Decorator to set a main function for the parser."""
@@ -1245,8 +1340,8 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
def bind_class(self, obj):
if not isinstance(obj, ArgparseCommand):
raise ValueError(
- "expected obj to be an instance of "
- "ArgparseCommand; got %r" % (obj,))
+ "expected obj to be an instance of " "ArgparseCommand; got %r" % (obj,)
+ )
obj.bind_to_parser(self)
return self
@@ -1261,10 +1356,12 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
"""Only run delayed default functor if the attribute isn't set."""
if isinstance(object.__getattribute__(namespace, attr), DelayedValue):
functor(namespace, attr)
+
if name is None:
name = functor.__name__
self.set_defaults(**{name: DelayedValue(default, priority)})
return functor
+
return f
def bind_parse_priority(self, priority):
@@ -1272,6 +1369,7 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
name = functor.__name__
self.set_defaults(**{name: OrderedParse(functor, priority)})
return functor
+
return f
def add_subparsers(self, **kwargs):
@@ -1280,9 +1378,9 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
if self.__subparsers is not None:
return self.__subparsers
- kwargs.setdefault('title', 'subcommands')
- kwargs.setdefault('dest', 'subcommand')
- kwargs.setdefault('prog', self.prog)
+ kwargs.setdefault("title", "subcommands")
+ kwargs.setdefault("dest", "subcommand")
+ kwargs.setdefault("prog", self.prog)
subparsers = argparse.ArgumentParser.add_subparsers(self, **kwargs)
subparsers.required = True
self.__subparsers = subparsers
@@ -1300,18 +1398,17 @@ class ArgumentParser(OptionalsParser, CsvActionsParser):
def bind_final_check(self, functor):
"""Decorator to bind a function for argument validation."""
- name = f'__final_check__{functor.__name__}'
+ name = f"__final_check__{functor.__name__}"
self.set_defaults(**{name: functor})
return functor
class ArgparseCommand:
-
def bind_to_parser(self, parser):
parser.bind_main_func(self)
def __call__(self, namespace, out, err):
- raise NotImplementedError(self, '__call__')
+ raise NotImplementedError(self, "__call__")
class FileType(argparse.FileType):
@@ -1322,11 +1419,11 @@ class FileType(argparse.FileType):
def __call__(self, string):
# the special argument "-" means sys.std{in,out}
- if string == '-':
- if 'r' in self._mode:
- return sys.stdin.buffer if 'b' in self._mode else sys.stdin
- elif any(c in self._mode for c in 'wax'):
- return sys.stdout.buffer if 'b' in self._mode else sys.stdout
+ if string == "-":
+ if "r" in self._mode:
+ return sys.stdin.buffer if "b" in self._mode else sys.stdin
+ elif any(c in self._mode for c in "wax"):
+ return sys.stdout.buffer if "b" in self._mode else sys.stdout
else:
msg = _('argument "-" with mode %r') % self._mode
raise ValueError(msg)
@@ -1342,23 +1439,27 @@ class FileType(argparse.FileType):
def existent_path(value):
"""Check if file argument path exists."""
if not os.path.exists(value):
- raise argparse.ArgumentTypeError(f'nonexistent path: {value!r}')
+ raise argparse.ArgumentTypeError(f"nonexistent path: {value!r}")
try:
return os.path.realpath(value)
except EnvironmentError as e:
- raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e
+ raise ValueError(
+ f"while resolving path {value!r}, encountered error: {e}"
+ ) from e
def existent_dir(value):
"""Check if argument path exists and is a directory."""
if not os.path.exists(value):
- raise argparse.ArgumentTypeError(f'nonexistent dir: {value!r}')
+ raise argparse.ArgumentTypeError(f"nonexistent dir: {value!r}")
elif not os.path.isdir(value):
- raise argparse.ArgumentTypeError(f'file already exists: {value!r}')
+ raise argparse.ArgumentTypeError(f"file already exists: {value!r}")
try:
return os.path.realpath(value)
except EnvironmentError as e:
- raise ValueError(f'while resolving path {value!r}, encountered error: {e}') from e
+ raise ValueError(
+ f"while resolving path {value!r}, encountered error: {e}"
+ ) from e
def create_dir(value):
@@ -1367,9 +1468,9 @@ def create_dir(value):
try:
os.makedirs(path, exist_ok=True)
except FileExistsError:
- raise argparse.ArgumentTypeError(f'file already exists: {value!r}')
+ raise argparse.ArgumentTypeError(f"file already exists: {value!r}")
except IOError as e:
- raise argparse.ArgumentTypeError(f'failed creating dir: {e}')
+ raise argparse.ArgumentTypeError(f"failed creating dir: {e}")
return path
@@ -1378,12 +1479,12 @@ def bounded_int(func, desc, x):
try:
n = int(x)
except ValueError:
- raise argparse.ArgumentTypeError('invalid integer value')
+ raise argparse.ArgumentTypeError("invalid integer value")
if not func(n):
- raise argparse.ArgumentTypeError(f'must be {desc}')
+ raise argparse.ArgumentTypeError(f"must be {desc}")
return n
def positive_int(x):
"""Check if argument is a positive integer."""
- return bounded_int(lambda n: n >= 1, '>= 1', x)
+ return bounded_int(lambda n: n >= 1, ">= 1", x)
diff --git a/src/snakeoil/cli/exceptions.py b/src/snakeoil/cli/exceptions.py
index fb0149fd..42b8d676 100644
--- a/src/snakeoil/cli/exceptions.py
+++ b/src/snakeoil/cli/exceptions.py
@@ -11,7 +11,7 @@ class UserException(Exception):
self._verbosity = verbosity
def msg(self, verbosity=0):
- return ''
+ return ""
class ExitException(Exception):
@@ -30,6 +30,8 @@ class ExitException(Exception):
def find_user_exception(exc):
"""Find the UserException related to a given exception if one exists."""
try:
- return next(e for e in walk_exception_chain(exc) if isinstance(e, UserException))
+ return next(
+ e for e in walk_exception_chain(exc) if isinstance(e, UserException)
+ )
except StopIteration:
return None
diff --git a/src/snakeoil/cli/input.py b/src/snakeoil/cli/input.py
index a89d61ed..9db9ea8c 100644
--- a/src/snakeoil/cli/input.py
+++ b/src/snakeoil/cli/input.py
@@ -41,9 +41,9 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3):
"""
if responses is None:
responses = {
- 'yes': (True, out.fg('green'), 'Yes'),
- 'no': (False, out.fg('red'), 'No'),
- }
+ "yes": (True, out.fg("green"), "Yes"),
+ "no": (False, out.fg("red"), "No"),
+ }
if default_answer is None:
default_answer = True
if default_answer is not None:
@@ -52,25 +52,25 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3):
default_answer_name = val[1:]
break
else:
- raise ValueError('default answer matches no responses')
+ raise ValueError("default answer matches no responses")
for i in range(limit):
# XXX see docstring about crummyness
if isinstance(prompt, tuple):
out.write(autoline=False, *prompt)
else:
out.write(prompt, autoline=False)
- out.write(' [', autoline=False)
+ out.write(" [", autoline=False)
prompts = list(responses.values())
for choice in prompts[:-1]:
out.write(autoline=False, *choice[1:])
- out.write(out.reset, '/', autoline=False)
+ out.write(out.reset, "/", autoline=False)
out.write(autoline=False, *prompts[-1][1:])
- out.write(out.reset, ']', autoline=False)
+ out.write(out.reset, "]", autoline=False)
if default_answer is not None:
- out.write(' (default: ', autoline=False)
+ out.write(" (default: ", autoline=False)
out.write(autoline=False, *default_answer_name)
- out.write(')', autoline=False)
- out.write(': ', autoline=False)
+ out.write(")", autoline=False)
+ out.write(": ", autoline=False)
try:
response = input()
except EOFError as e:
@@ -83,15 +83,20 @@ def userquery(prompt, out, err, responses=None, default_answer=None, limit=3):
raise
if not response:
return default_answer
- results = sorted(set(
- (key, value) for key, value in responses.items()
- if key[:len(response)].lower() == response.lower()))
+ results = sorted(
+ set(
+ (key, value)
+ for key, value in responses.items()
+ if key[: len(response)].lower() == response.lower()
+ )
+ )
if not results:
- err.write('Sorry, response %r not understood.' % (response,))
+ err.write("Sorry, response %r not understood." % (response,))
elif len(results) > 1:
err.write(
- 'Response %r is ambiguous (%s)' %
- (response, ', '.join(key for key, val in results)))
+ "Response %r is ambiguous (%s)"
+ % (response, ", ".join(key for key, val in results))
+ )
else:
return list(results)[0][1][0]
diff --git a/src/snakeoil/cli/tool.py b/src/snakeoil/cli/tool.py
index 2a142b72..d00b3cf2 100644
--- a/src/snakeoil/cli/tool.py
+++ b/src/snakeoil/cli/tool.py
@@ -36,14 +36,14 @@ class Tool:
if not sys.stdout.isatty() and sys.stdout == sys.__stdout__:
# if redirecting/piping stdout use line buffering, skip if
# stdout has been set to some non-standard object
- outfile = os.fdopen(sys.stdout.fileno(), 'w', 1)
+ outfile = os.fdopen(sys.stdout.fileno(), "w", 1)
else:
outfile = sys.stdout
if errfile is None:
errfile = sys.stderr
out_fd = err_fd = None
- if hasattr(outfile, 'fileno') and hasattr(errfile, 'fileno'):
+ if hasattr(outfile, "fileno") and hasattr(errfile, "fileno"):
# annoyingly, fileno can exist but through unsupport
try:
out_fd, err_fd = outfile.fileno(), errfile.fileno()
@@ -52,9 +52,11 @@ class Tool:
if out_fd is not None and err_fd is not None:
out_stat, err_stat = os.fstat(out_fd), os.fstat(err_fd)
- if out_stat.st_dev == err_stat.st_dev \
- and out_stat.st_ino == err_stat.st_ino and \
- not errfile.isatty():
+ if (
+ out_stat.st_dev == err_stat.st_dev
+ and out_stat.st_ino == err_stat.st_ino
+ and not errfile.isatty()
+ ):
# they're the same underlying fd. thus
# point the handles at the same so we don't
# get intermixed buffering issues.
@@ -64,7 +66,7 @@ class Tool:
self._errfile = errfile
self.out = self.parser.out = formatters.PlainTextFormatter(outfile)
self.err = self.parser.err = formatters.PlainTextFormatter(errfile)
- self.out.verbosity = self.err.verbosity = getattr(self.parser, 'verbosity', 0)
+ self.out.verbosity = self.err.verbosity = getattr(self.parser, "verbosity", 0)
def __call__(self, args=None):
"""Run the utility.
@@ -98,19 +100,21 @@ class Tool:
try:
self.pre_parse(args, namespace)
options = self.parser.parse_args(args=args, namespace=namespace)
- main_func = options.pop('main_func', None)
+ main_func = options.pop("main_func", None)
if main_func is None:
raise RuntimeError("argparser missing main method")
# reconfigure formatters for colored output if enabled
- if getattr(options, 'color', True):
+ if getattr(options, "color", True):
formatter_factory = partial(
- formatters.get_formatter, force_color=getattr(options, 'color', False))
+ formatters.get_formatter,
+ force_color=getattr(options, "color", False),
+ )
self.out = formatter_factory(self._outfile)
self.err = formatter_factory(self._errfile)
# reconfigure formatters with properly parsed output verbosity
- self.out.verbosity = self.err.verbosity = getattr(options, 'verbosity', 0)
+ self.out.verbosity = self.err.verbosity = getattr(options, "verbosity", 0)
if logging.root.handlers:
# Remove the default handler.
@@ -138,13 +142,13 @@ class Tool:
exc = find_user_exception(e)
if exc is not None:
# allow exception attribute to override user verbosity level
- if getattr(exc, '_verbosity', None) is not None:
+ if getattr(exc, "_verbosity", None) is not None:
verbosity = exc._verbosity
else:
- verbosity = getattr(self.parser, 'verbosity', 0)
+ verbosity = getattr(self.parser, "verbosity", 0)
# output verbose error message if it exists
if verbosity > 0:
- msg = exc.msg(verbosity).strip('\n')
+ msg = exc.msg(verbosity).strip("\n")
if msg:
self.err.write(msg)
raise SystemExit
@@ -166,15 +170,17 @@ class Tool:
try:
with suppress_warnings:
- self.options, func = self.parse_args(args=self.args, namespace=self.options)
+ self.options, func = self.parse_args(
+ args=self.args, namespace=self.options
+ )
exitstatus = func(self.options, self.out, self.err)
except SystemExit as e:
# handle argparse or other third party modules using sys.exit internally
exitstatus = e.code
except KeyboardInterrupt:
- self._errfile.write('keyboard interrupted- exiting')
+ self._errfile.write("keyboard interrupted- exiting")
if self.parser.debug:
- self._errfile.write('\n')
+ self._errfile.write("\n")
traceback.print_exc()
signal(SIGINT, SIG_DFL)
os.killpg(os.getpgid(0), SIGINT)
@@ -187,9 +193,9 @@ class Tool:
if self.options is not None:
# set terminal title on exit
if exitstatus:
- self.out.title(f'{self.options.prog} failed')
+ self.out.title(f"{self.options.prog} failed")
else:
- self.out.title(f'{self.options.prog} succeeded')
+ self.out.title(f"{self.options.prog} succeeded")
return exitstatus
@@ -204,18 +210,25 @@ class FormattingHandler(logging.Handler):
def emit(self, record):
if record.levelno >= logging.ERROR:
- color = 'red'
+ color = "red"
elif record.levelno >= logging.WARNING:
- color = 'yellow'
+ color = "yellow"
else:
- color = 'cyan'
- first_prefix = (self.out.fg(color), self.out.bold, record.levelname,
- self.out.reset, ' ', record.name, ': ')
- later_prefix = (len(record.levelname) + len(record.name)) * ' ' + ' : '
+ color = "cyan"
+ first_prefix = (
+ self.out.fg(color),
+ self.out.bold,
+ record.levelname,
+ self.out.reset,
+ " ",
+ record.name,
+ ": ",
+ )
+ later_prefix = (len(record.levelname) + len(record.name)) * " " + " : "
self.out.first_prefix.extend(first_prefix)
self.out.later_prefix.append(later_prefix)
try:
- for line in self.format(record).split('\n'):
+ for line in self.format(record).split("\n"):
self.out.write(line, wrap=True)
except Exception:
self.handleError(record)
diff --git a/src/snakeoil/compatibility.py b/src/snakeoil/compatibility.py
index c9a8a545..4fee4173 100644
--- a/src/snakeoil/compatibility.py
+++ b/src/snakeoil/compatibility.py
@@ -8,12 +8,15 @@ __all__ = ("cmp", "sorted_cmp", "sort_cmp")
def sorted_key_from_cmp(cmp_func, key_func=None):
class _key_proxy:
- __slots__ = ('_obj',)
+ __slots__ = ("_obj",)
+
+ if key_func: # done this way for speed reasons.
- if key_func: # done this way for speed reasons.
def __init__(self, obj, key_convert=key_func):
self._obj = key_convert(obj)
+
else:
+
def __init__(self, obj):
self._obj = obj
@@ -40,13 +43,13 @@ def cmp(obj1, obj2, raw_cmp=_raw_cmp):
def sorted_cmp(sequence, func, key=None, reverse=False):
- return sorted(sequence, reverse=reverse,
- key=sorted_key_from_cmp(func, key_func=key))
+ return sorted(
+ sequence, reverse=reverse, key=sorted_key_from_cmp(func, key_func=key)
+ )
def sort_cmp(list_inst, func, key=None, reverse=False):
- list_inst.sort(reverse=reverse,
- key=sorted_key_from_cmp(func, key_func=key))
+ list_inst.sort(reverse=reverse, key=sorted_key_from_cmp(func, key_func=key))
IGNORED_EXCEPTIONS = (RuntimeError, MemoryError, SystemExit, KeyboardInterrupt)
diff --git a/src/snakeoil/compression/__init__.py b/src/snakeoil/compression/__init__.py
index 77c0631b..9eee103b 100644
--- a/src/snakeoil/compression/__init__.py
+++ b/src/snakeoil/compression/__init__.py
@@ -9,13 +9,12 @@ from ..process.spawn import spawn_get_output
class _transform_source:
-
def __init__(self, name):
self.name = name
@cached_property
def module(self):
- return import_module(f'snakeoil.compression._{self.name}')
+ return import_module(f"snakeoil.compression._{self.name}")
def compress_data(self, data, level, parallelize=False):
parallelize = parallelize and self.module.parallelizable
@@ -34,7 +33,7 @@ class _transform_source:
return self.module.decompress_handle(handle, parallelize=parallelize)
-_transforms = {name: _transform_source(name) for name in ('bzip2', 'xz')}
+_transforms = {name: _transform_source(name) for name in ("bzip2", "xz")}
def compress_data(compressor_type, data, level=9, **kwds):
@@ -73,13 +72,13 @@ class ArComp:
cls = cls.known_exts[ext]
return super(ArComp, cls).__new__(cls)
except KeyError:
- raise ArCompError(f'unknown compression file extension: {ext!r}')
+ raise ArCompError(f"unknown compression file extension: {ext!r}")
def __init_subclass__(cls, **kwargs):
"""Initialize result subclasses and register archive extensions."""
super().__init_subclass__(**kwargs)
- if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover
- raise ValueError(f'class missing required attrs: {cls!r}')
+ if not all((cls.binary, cls.default_unpack_cmd, cls.exts)): # pragma: no cover
+ raise ValueError(f"class missing required attrs: {cls!r}")
for ext in cls.exts:
cls.known_exts[ext] = cls
@@ -95,9 +94,10 @@ class ArComp:
except process.CommandNotFound:
continue
else:
- choices = ', '.join(self.binary)
+ choices = ", ".join(self.binary)
raise ArCompError(
- f'required binary not found from the following choices: {choices}')
+ f"required binary not found from the following choices: {choices}"
+ )
cmd = self.default_unpack_cmd.format(binary=binary, path=self.path)
return cmd
@@ -112,7 +112,7 @@ class _Archive:
cmd = shlex.split(self._unpack_cmd.format(path=self.path))
ret, output = spawn_get_output(cmd, collect_fds=(2,), **kwargs)
if ret:
- msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}'
+ msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}"
raise ArCompError(msg, code=ret)
@@ -121,11 +121,12 @@ class _CompressedFile:
def unpack(self, dest=None, **kwargs):
cmd = shlex.split(self._unpack_cmd.format(path=self.path))
- with open(dest, 'wb') as f:
+ with open(dest, "wb") as f:
ret, output = spawn_get_output(
- cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs)
+ cmd, collect_fds=(2,), fd_pipes={1: f.fileno()}, **kwargs
+ )
if ret:
- msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}'
+ msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}"
raise ArCompError(msg, code=ret)
@@ -134,18 +135,25 @@ class _CompressedStdin:
def unpack(self, dest=None, **kwargs):
cmd = shlex.split(self._unpack_cmd)
- with open(self.path, 'rb') as src, open(dest, 'wb') as f:
+ with open(self.path, "rb") as src, open(dest, "wb") as f:
ret, output = spawn_get_output(
- cmd, collect_fds=(2,), fd_pipes={0: src.fileno(), 1: f.fileno()}, **kwargs)
+ cmd,
+ collect_fds=(2,),
+ fd_pipes={0: src.fileno(), 1: f.fileno()},
+ **kwargs,
+ )
if ret:
- msg = '\n'.join(output) if output else f'unpacking failed: {self.path!r}'
+ msg = "\n".join(output) if output else f"unpacking failed: {self.path!r}"
raise ArCompError(msg, code=ret)
class _Tar(_Archive, ArComp):
- exts = frozenset(['.tar'])
- binary = ('gtar', 'tar',)
+ exts = frozenset([".tar"])
+ binary = (
+ "gtar",
+ "tar",
+ )
compress_binary = None
default_unpack_cmd = '{binary} xf "{path}"'
@@ -162,95 +170,96 @@ class _Tar(_Archive, ArComp):
except process.CommandNotFound:
pass
else:
- choices = ', '.join(next(zip(*self.compress_binary)))
+ choices = ", ".join(next(zip(*self.compress_binary)))
raise ArCompError(
- 'no compression binary found from the '
- f'following choices: {choices}')
+ "no compression binary found from the "
+ f"following choices: {choices}"
+ )
return cmd
class _TarGZ(_Tar):
- exts = frozenset(['.tar.gz', '.tgz', '.tar.Z', '.tar.z'])
- compress_binary = (('pigz',), ('gzip',))
+ exts = frozenset([".tar.gz", ".tgz", ".tar.Z", ".tar.z"])
+ compress_binary = (("pigz",), ("gzip",))
class _TarBZ2(_Tar):
- exts = frozenset(['.tar.bz2', '.tbz2', '.tbz'])
- compress_binary = (('lbzip2',), ('pbzip2',), ('bzip2',))
+ exts = frozenset([".tar.bz2", ".tbz2", ".tbz"])
+ compress_binary = (("lbzip2",), ("pbzip2",), ("bzip2",))
class _TarLZMA(_Tar):
- exts = frozenset(['.tar.lzma'])
- compress_binary = (('lzma',))
+ exts = frozenset([".tar.lzma"])
+ compress_binary = ("lzma",)
class _TarXZ(_Tar):
- exts = frozenset(['.tar.xz', '.txz'])
- compress_binary = (('pixz',), ('xz', f'-T{multiprocessing.cpu_count()}'))
+ exts = frozenset([".tar.xz", ".txz"])
+ compress_binary = (("pixz",), ("xz", f"-T{multiprocessing.cpu_count()}"))
class _Zip(_Archive, ArComp):
- exts = frozenset(['.ZIP', '.zip', '.jar'])
- binary = ('unzip',)
+ exts = frozenset([".ZIP", ".zip", ".jar"])
+ binary = ("unzip",)
default_unpack_cmd = '{binary} -qo "{path}"'
class _GZ(_CompressedStdin, ArComp):
- exts = frozenset(['.gz', '.Z', '.z'])
- binary = ('pigz', 'gzip')
- default_unpack_cmd = '{binary} -d -c'
+ exts = frozenset([".gz", ".Z", ".z"])
+ binary = ("pigz", "gzip")
+ default_unpack_cmd = "{binary} -d -c"
class _BZ2(_CompressedStdin, ArComp):
- exts = frozenset(['.bz2', '.bz'])
- binary = ('lbzip2', 'pbzip2', 'bzip2')
- default_unpack_cmd = '{binary} -d -c'
+ exts = frozenset([".bz2", ".bz"])
+ binary = ("lbzip2", "pbzip2", "bzip2")
+ default_unpack_cmd = "{binary} -d -c"
class _XZ(_CompressedStdin, ArComp):
- exts = frozenset(['.xz'])
- binary = ('pixz', 'xz')
- default_unpack_cmd = '{binary} -d -c'
+ exts = frozenset([".xz"])
+ binary = ("pixz", "xz")
+ default_unpack_cmd = "{binary} -d -c"
class _7Z(_Archive, ArComp):
- exts = frozenset(['.7Z', '.7z'])
- binary = ('7z',)
+ exts = frozenset([".7Z", ".7z"])
+ binary = ("7z",)
default_unpack_cmd = '{binary} x -y "{path}"'
class _Rar(_Archive, ArComp):
- exts = frozenset(['.RAR', '.rar'])
- binary = ('unrar',)
+ exts = frozenset([".RAR", ".rar"])
+ binary = ("unrar",)
default_unpack_cmd = '{binary} x -idq -o+ "{path}"'
class _LHA(_Archive, ArComp):
- exts = frozenset(['.LHa', '.LHA', '.lha', '.lzh'])
- binary = ('lha',)
+ exts = frozenset([".LHa", ".LHA", ".lha", ".lzh"])
+ binary = ("lha",)
default_unpack_cmd = '{binary} xfq "{path}"'
class _Ar(_Archive, ArComp):
- exts = frozenset(['.a', '.deb'])
- binary = ('ar',)
+ exts = frozenset([".a", ".deb"])
+ binary = ("ar",)
default_unpack_cmd = '{binary} x "{path}"'
class _LZMA(_CompressedFile, ArComp):
- exts = frozenset(['.lzma'])
- binary = ('lzma',)
+ exts = frozenset([".lzma"])
+ binary = ("lzma",)
default_unpack_cmd = '{binary} -dc "{path}"'
diff --git a/src/snakeoil/compression/_bzip2.py b/src/snakeoil/compression/_bzip2.py
index 1a38922c..122debd1 100644
--- a/src/snakeoil/compression/_bzip2.py
+++ b/src/snakeoil/compression/_bzip2.py
@@ -25,6 +25,7 @@ bz2_path = process.find_binary("bzip2")
try:
from bz2 import BZ2File, compress as _compress_data, decompress as _decompress_data
+
native = True
except ImportError:
@@ -40,7 +41,7 @@ _decompress_handle = partial(_util.decompress_handle, bz2_path)
try:
lbzip2_path = process.find_binary("lbzip2")
- lbzip2_compress_args = (f'-n{multiprocessing.cpu_count()}', )
+ lbzip2_compress_args = (f"-n{multiprocessing.cpu_count()}",)
lbzip2_decompress_args = lbzip2_compress_args
parallelizable = True
except process.CommandNotFound:
@@ -51,28 +52,35 @@ except process.CommandNotFound:
def compress_data(data, level=9, parallelize=False):
if parallelize and parallelizable:
- return _util.compress_data(lbzip2_path, data, compresslevel=level,
- extra_args=lbzip2_compress_args)
+ return _util.compress_data(
+ lbzip2_path, data, compresslevel=level, extra_args=lbzip2_compress_args
+ )
return _compress_data(data, compresslevel=level)
+
def decompress_data(data, parallelize=False):
if parallelize and parallelizable:
- return _util.decompress_data(lbzip2_path, data,
- extra_args=lbzip2_decompress_args)
+ return _util.decompress_data(
+ lbzip2_path, data, extra_args=lbzip2_decompress_args
+ )
return _decompress_data(data)
+
def compress_handle(handle, level=9, parallelize=False):
if parallelize and parallelizable:
- return _util.compress_handle(lbzip2_path, handle, compresslevel=level,
- extra_args=lbzip2_compress_args)
+ return _util.compress_handle(
+ lbzip2_path, handle, compresslevel=level, extra_args=lbzip2_compress_args
+ )
elif native and isinstance(handle, str):
- return BZ2File(handle, mode='w', compresslevel=level)
+ return BZ2File(handle, mode="w", compresslevel=level)
return _compress_handle(handle, compresslevel=level)
+
def decompress_handle(handle, parallelize=False):
if parallelize and parallelizable:
- return _util.decompress_handle(lbzip2_path, handle,
- extra_args=lbzip2_decompress_args)
+ return _util.decompress_handle(
+ lbzip2_path, handle, extra_args=lbzip2_decompress_args
+ )
elif native and isinstance(handle, str):
- return BZ2File(handle, mode='r')
+ return BZ2File(handle, mode="r")
return _decompress_handle(handle)
diff --git a/src/snakeoil/compression/_util.py b/src/snakeoil/compression/_util.py
index e1af5aef..b95d80c1 100644
--- a/src/snakeoil/compression/_util.py
+++ b/src/snakeoil/compression/_util.py
@@ -6,15 +6,20 @@ import subprocess
def _drive_process(args, mode, data):
- p = subprocess.Popen(args,
- stdin=subprocess.PIPE, stdout=subprocess.PIPE,
- stderr=subprocess.PIPE, close_fds=True)
+ p = subprocess.Popen(
+ args,
+ stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ close_fds=True,
+ )
try:
stdout, stderr = p.communicate(data)
if p.returncode != 0:
- args = ' '.join(args)
+ args = " ".join(args)
raise ValueError(
- f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}")
+ f"{mode} returned {p.returncode} exitcode from '{args}', stderr={stderr.decode()}"
+ )
return stdout
finally:
if p is not None and p.returncode is None:
@@ -22,21 +27,20 @@ def _drive_process(args, mode, data):
def compress_data(binary, data, compresslevel=9, extra_args=()):
- args = [binary, f'-{compresslevel}c']
+ args = [binary, f"-{compresslevel}c"]
args.extend(extra_args)
- return _drive_process(args, 'compression', data)
+ return _drive_process(args, "compression", data)
def decompress_data(binary, data, extra_args=()):
- args = [binary, '-dc']
+ args = [binary, "-dc"]
args.extend(extra_args)
- return _drive_process(args, 'decompression', data)
+ return _drive_process(args, "decompression", data)
class _process_handle:
-
def __init__(self, handle, args, is_read=False):
- self.mode = 'rb' if is_read else 'wb'
+ self.mode = "rb" if is_read else "wb"
self.args = tuple(args)
self.is_read = is_read
@@ -51,9 +55,10 @@ class _process_handle:
handle = open(handle, mode=self.mode)
close = True
elif not isinstance(handle, int):
- if not hasattr(handle, 'fileno'):
+ if not hasattr(handle, "fileno"):
raise TypeError(
- f"handle {handle!r} isn't a string, integer, and lacks a fileno method")
+ f"handle {handle!r} isn't a string, integer, and lacks a fileno method"
+ )
handle = handle.fileno()
try:
@@ -64,18 +69,17 @@ class _process_handle:
def _setup_process(self, handle):
self.position = 0
- stderr = open(os.devnull, 'wb')
+ stderr = open(os.devnull, "wb")
kwds = dict(stderr=stderr)
if self.is_read:
- kwds['stdin'] = handle
- kwds['stdout'] = subprocess.PIPE
+ kwds["stdin"] = handle
+ kwds["stdout"] = subprocess.PIPE
else:
- kwds['stdout'] = handle
- kwds['stdin'] = subprocess.PIPE
+ kwds["stdout"] = handle
+ kwds["stdin"] = subprocess.PIPE
try:
- self._process = subprocess.Popen(
- self.args, close_fds=True, **kwds)
+ self._process = subprocess.Popen(self.args, close_fds=True, **kwds)
finally:
stderr.close()
@@ -106,7 +110,8 @@ class _process_handle:
if self._allow_reopen is None:
raise TypeError(
f"instance {self} can't do negative seeks: "
- f"asked for {position}, was at {self.position}")
+ f"asked for {position}, was at {self.position}"
+ )
self._terminate()
self._open_handle(self._allow_reopen)
return self.seek(position)
@@ -130,7 +135,7 @@ class _process_handle:
# reallocating it continually; via this usage, we
# only slice once the val is less than seek_size;
# iow, two allocations worst case.
- null_block = '\0' * seek_size
+ null_block = "\0" * seek_size
while val:
self.write(null_block[:val])
offset -= val
@@ -145,11 +150,13 @@ class _process_handle:
raise
def close(self):
- if not hasattr(self, '_process'):
+ if not hasattr(self, "_process"):
return
if self._process.returncode is not None:
if self._process.returncode != 0:
- raise Exception(f"{self.args} invocation had non zero exit: {self._process.returncode}")
+ raise Exception(
+ f"{self.args} invocation had non zero exit: {self._process.returncode}"
+ )
return
self.handle.close()
@@ -163,12 +170,12 @@ class _process_handle:
def compress_handle(binary_path, handle, compresslevel=9, extra_args=()):
- args = [binary_path, f'-{compresslevel}c']
+ args = [binary_path, f"-{compresslevel}c"]
args.extend(extra_args)
return _process_handle(handle, args, False)
def decompress_handle(binary_path, handle, extra_args=()):
- args = [binary_path, '-dc']
+ args = [binary_path, "-dc"]
args.extend(extra_args)
return _process_handle(handle, args, True)
diff --git a/src/snakeoil/compression/_xz.py b/src/snakeoil/compression/_xz.py
index 47077379..9a91a4c5 100644
--- a/src/snakeoil/compression/_xz.py
+++ b/src/snakeoil/compression/_xz.py
@@ -21,7 +21,7 @@ from ..compression import _util
# if xz can't be found, throw an error.
xz_path = process.find_binary("xz")
-xz_compress_args = (f'-T{multiprocessing.cpu_count()}',)
+xz_compress_args = (f"-T{multiprocessing.cpu_count()}",)
xz_decompress_args = xz_compress_args
parallelizable = True
@@ -29,6 +29,7 @@ try:
from lzma import LZMAFile
from lzma import compress as _compress_data
from lzma import decompress as _decompress_data
+
native = True
except ImportError:
@@ -45,30 +46,33 @@ _decompress_handle = partial(_util.decompress_handle, xz_path)
def compress_data(data, level=9, parallelize=False):
if parallelize and parallelizable:
- return _util.compress_data(xz_path, data, compresslevel=level,
- extra_args=xz_compress_args)
+ return _util.compress_data(
+ xz_path, data, compresslevel=level, extra_args=xz_compress_args
+ )
if native:
return _compress_data(data, preset=level)
return _compress_data(data, compresslevel=level)
+
def decompress_data(data, parallelize=False):
if parallelize and parallelizable:
- return _util.decompress_data(xz_path, data,
- extra_args=xz_decompress_args)
+ return _util.decompress_data(xz_path, data, extra_args=xz_decompress_args)
return _decompress_data(data)
+
def compress_handle(handle, level=9, parallelize=False):
if parallelize and parallelizable:
- return _util.compress_handle(xz_path, handle, compresslevel=level,
- extra_args=xz_compress_args)
+ return _util.compress_handle(
+ xz_path, handle, compresslevel=level, extra_args=xz_compress_args
+ )
elif native and isinstance(handle, str):
- return LZMAFile(handle, mode='w', preset=level)
+ return LZMAFile(handle, mode="w", preset=level)
return _compress_handle(handle, compresslevel=level)
+
def decompress_handle(handle, parallelize=False):
if parallelize and parallelizable:
- return _util.decompress_handle(xz_path, handle,
- extra_args=xz_decompress_args)
- elif (native and isinstance(handle, str)):
- return LZMAFile(handle, mode='r')
+ return _util.decompress_handle(xz_path, handle, extra_args=xz_decompress_args)
+ elif native and isinstance(handle, str):
+ return LZMAFile(handle, mode="r")
return _decompress_handle(handle)
diff --git a/src/snakeoil/constraints.py b/src/snakeoil/constraints.py
index 63e46715..c239727c 100644
--- a/src/snakeoil/constraints.py
+++ b/src/snakeoil/constraints.py
@@ -31,8 +31,9 @@ class Constraint(Protocol):
domain.
:return: ``True`` if the assignment is satisfied.
"""
+
def __call__(self, **kwargs: Any) -> bool:
- raise NotImplementedError('Constraint', '__call__')
+ raise NotImplementedError("Constraint", "__call__")
class _Domain(list):
@@ -75,10 +76,13 @@ class Problem:
of a :py:class:`dict` assigning to each variable in the problem a
single value from it's domain.
"""
+
def __init__(self):
self.variables: dict[str, _Domain] = {}
self.constraints: list[tuple[Constraint, frozenset[str]]] = []
- self.vconstraints: dict[str, list[tuple[Constraint, frozenset[str]]]] = defaultdict(list)
+ self.vconstraints: dict[
+ str, list[tuple[Constraint, frozenset[str]]]
+ ] = defaultdict(list)
def add_variable(self, domain: Iterable[Any], *variables: str):
"""Add variables to the problem, which use the specified domain.
@@ -94,7 +98,9 @@ class Problem:
from each domain.
"""
for variable in variables:
- assert variable not in self.variables, f'variable {variable!r} was already added'
+ assert (
+ variable not in self.variables
+ ), f"variable {variable!r} was already added"
self.variables[variable] = _Domain(domain)
def add_constraint(self, constraint: Constraint, variables: frozenset[str]):
@@ -110,10 +116,15 @@ class Problem:
"""
self.constraints.append((constraint, variables))
for variable in variables:
- assert variable in self.variables, f'unknown variable {variable!r}'
+ assert variable in self.variables, f"unknown variable {variable!r}"
self.vconstraints[variable].append((constraint, variables))
- def __check(self, constraint: Constraint, variables: frozenset[str], assignments: dict[str, Any]) -> bool:
+ def __check(
+ self,
+ constraint: Constraint,
+ variables: frozenset[str],
+ assignments: dict[str, Any],
+ ) -> bool:
assignments = {k: v for k, v in assignments.items() if k in variables}
unassigned = variables - assignments.keys()
if not unassigned:
@@ -147,14 +158,17 @@ class Problem:
# mix the Degree and Minimum Remaining Values (MRV) heuristics
lst = sorted(
(-len(self.vconstraints[name]), len(domain), name)
- for name, domain in self.variables.items())
+ for name, domain in self.variables.items()
+ )
for _, _, variable in lst:
if variable not in assignments:
values = self.variables[variable][:]
push_domains = tuple(
- domain for name, domain in self.variables.items()
- if name != variable and name not in assignments)
+ domain
+ for name, domain in self.variables.items()
+ if name != variable and name not in assignments
+ )
break
else:
# no unassigned variables, we've got a solution.
diff --git a/src/snakeoil/containers.py b/src/snakeoil/containers.py
index ba211556..ebfc1b39 100644
--- a/src/snakeoil/containers.py
+++ b/src/snakeoil/containers.py
@@ -4,8 +4,12 @@ Container classes and functionality for implementing them
"""
__all__ = (
- "InvertedContains", "SetMixin", "LimitedChangeSet", "Unchangable",
- "ProtectedSet", "RefCountingSet"
+ "InvertedContains",
+ "SetMixin",
+ "LimitedChangeSet",
+ "Unchangable",
+ "ProtectedSet",
+ "RefCountingSet",
)
from itertools import chain, filterfalse
@@ -70,9 +74,11 @@ class SetMixin:
@steal_docs(set)
def __xor__(self, other, kls=None):
- return (kls or self.__class__)(chain(
- (x for x in self if x not in other),
- (x for x in other if x not in self)))
+ return (kls or self.__class__)(
+ chain(
+ (x for x in self if x not in other), (x for x in other if x not in self)
+ )
+ )
@steal_docs(set)
def __rxor__(self, other):
@@ -120,8 +126,7 @@ class LimitedChangeSet(SetMixin):
def _default_key_validator(val):
return val
- def __init__(self, initial_keys, unchangable_keys=None,
- key_validator=None):
+ def __init__(self, initial_keys, unchangable_keys=None, key_validator=None):
"""
:param initial_keys: iterable holding the initial values to set
:param unchangable_keys: container holding keys that cannot be changed
@@ -185,8 +190,7 @@ class LimitedChangeSet(SetMixin):
def rollback(self, point=0):
l = self.changes_count()
if point < 0 or point > l:
- raise TypeError(
- "%s point must be >=0 and <= changes_count()" % point)
+ raise TypeError("%s point must be >=0 and <= changes_count()" % point)
while l > point:
change, key = self._change_order.pop(-1)
self._changed.remove(key)
@@ -221,9 +225,8 @@ class LimitedChangeSet(SetMixin):
class Unchangable(Exception):
-
def __init__(self, key):
- super().__init__(f'key {key!r} is unchangable')
+ super().__init__(f"key {key!r} is unchangable")
self.key = key
@@ -240,6 +243,7 @@ class ProtectedSet(SetMixin):
>>> myset.remove(2)
>>> assert 2 not in protected
"""
+
def __init__(self, orig_set):
self._orig = orig_set
self._new = set()
diff --git a/src/snakeoil/contexts.py b/src/snakeoil/contexts.py
index 57092b4d..394574b5 100644
--- a/src/snakeoil/contexts.py
+++ b/src/snakeoil/contexts.py
@@ -41,6 +41,7 @@ from .sequences import predicate_split
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+
class SplitExec:
"""Context manager separating code execution across parent/child processes.
@@ -48,6 +49,7 @@ class SplitExec:
of the context are executed only on the forked child. Exceptions are
pickled and passed back to the parent.
"""
+
def __init__(self):
self.__trace_lock = threading.Lock()
self.__orig_sys_trace = None
@@ -184,7 +186,7 @@ class SplitExec:
@staticmethod
def __excepthook(_exc_type, exc_value, exc_traceback):
"""Output the proper traceback information from the chroot context."""
- if hasattr(exc_value, '__traceback_list__'):
+ if hasattr(exc_value, "__traceback_list__"):
sys.stderr.write(exc_value.__traceback_list__)
else:
traceback.print_tb(exc_traceback)
@@ -253,7 +255,7 @@ class SplitExec:
except AttributeError:
# an offset of two accounts for this method and its caller
frame = inspect.stack(0)[2][0]
- while frame.f_locals.get('self') is self:
+ while frame.f_locals.get("self") is self:
frame = frame.f_back
self.__frame = frame # pylint: disable=W0201
return frame
@@ -262,11 +264,24 @@ class SplitExec:
class Namespace(SplitExec):
"""Context manager that provides Linux namespace support."""
- def __init__(self, mount=False, uts=True, ipc=False, net=False, pid=False,
- user=False, hostname=None):
+ def __init__(
+ self,
+ mount=False,
+ uts=True,
+ ipc=False,
+ net=False,
+ pid=False,
+ user=False,
+ hostname=None,
+ ):
self._hostname = hostname
self._namespaces = {
- 'mount': mount, 'uts': uts, 'ipc': ipc, 'net': net, 'pid': pid, 'user': user,
+ "mount": mount,
+ "uts": uts,
+ "ipc": ipc,
+ "net": net,
+ "pid": pid,
+ "user": user,
}
super().__init__()
@@ -279,8 +294,8 @@ class GitStash(AbstractContextManager):
def __init__(self, path, pathspecs=None, staged=False):
self.path = path
- self.pathspecs = ['--'] + pathspecs if pathspecs else []
- self._staged = ['--keep-index'] if staged else []
+ self.pathspecs = ["--"] + pathspecs if pathspecs else []
+ self._staged = ["--keep-index"] if staged else []
self._stashed = False
def __enter__(self):
@@ -288,14 +303,18 @@ class GitStash(AbstractContextManager):
# check for untracked or modified/uncommitted files
try:
p = subprocess.run(
- ['git', 'status', '--porcelain=1', '-u'] + self.pathspecs,
- stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
- cwd=self.path, encoding='utf8', check=True)
+ ["git", "status", "--porcelain=1", "-u"] + self.pathspecs,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.DEVNULL,
+ cwd=self.path,
+ encoding="utf8",
+ check=True,
+ )
except subprocess.CalledProcessError:
- raise ValueError(f'not a git repo: {self.path}')
+ raise ValueError(f"not a git repo: {self.path}")
# split file changes into unstaged vs staged
- unstaged, staged = predicate_split(lambda x: x[1] == ' ', p.stdout.splitlines())
+ unstaged, staged = predicate_split(lambda x: x[1] == " ", p.stdout.splitlines())
# don't stash when no relevant changes exist
if self._staged:
@@ -306,14 +325,18 @@ class GitStash(AbstractContextManager):
# stash all existing untracked or modified/uncommitted files
try:
- stash_cmd = ['git', 'stash', 'push', '-u', '-m', 'pkgcheck scan --commits']
+ stash_cmd = ["git", "stash", "push", "-u", "-m", "pkgcheck scan --commits"]
subprocess.run(
stash_cmd + self._staged + self.pathspecs,
- stdout=subprocess.DEVNULL, stderr=subprocess.PIPE,
- cwd=self.path, check=True, encoding='utf8')
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ cwd=self.path,
+ check=True,
+ encoding="utf8",
+ )
except subprocess.CalledProcessError as e:
error = e.stderr.splitlines()[0]
- raise UserException(f'git failed stashing files: {error}')
+ raise UserException(f"git failed stashing files: {error}")
self._stashed = True
def __exit__(self, _exc_type, _exc_value, _traceback):
@@ -321,12 +344,16 @@ class GitStash(AbstractContextManager):
if self._stashed:
try:
subprocess.run(
- ['git', 'stash', 'pop'],
- stdout=subprocess.DEVNULL, stderr=subprocess.PIPE,
- cwd=self.path, check=True, encoding='utf8')
+ ["git", "stash", "pop"],
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.PIPE,
+ cwd=self.path,
+ check=True,
+ encoding="utf8",
+ )
except subprocess.CalledProcessError as e:
error = e.stderr.splitlines()[0]
- raise UserException(f'git failed applying stash: {error}')
+ raise UserException(f"git failed applying stash: {error}")
@contextmanager
@@ -347,7 +374,7 @@ def chdir(path):
@contextmanager
-def syspath(path: str, condition: bool=True, position: int=0):
+def syspath(path: str, condition: bool = True, position: int = 0):
"""Context manager that mangles ``sys.path`` and then reverts on exit.
:param path: The directory path to add to ``sys.path``.
@@ -425,6 +452,7 @@ def os_environ(*remove, **update):
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
@contextmanager
def patch(target, new):
"""Simplified module monkey patching via context manager.
@@ -434,7 +462,7 @@ def patch(target, new):
"""
def _import_module(target):
- components = target.split('.')
+ components = target.split(".")
import_path = components.pop(0)
module = import_module(import_path)
for comp in components:
@@ -448,16 +476,16 @@ def patch(target, new):
def _get_target(target):
if isinstance(target, str):
try:
- module, attr = target.rsplit('.', 1)
+ module, attr = target.rsplit(".", 1)
except (TypeError, ValueError):
- raise TypeError(f'invalid target: {target!r}')
+ raise TypeError(f"invalid target: {target!r}")
module = _import_module(module)
return module, attr
else:
try:
obj, attr = target
except (TypeError, ValueError):
- raise TypeError(f'invalid target: {target!r}')
+ raise TypeError(f"invalid target: {target!r}")
return obj, attr
obj, attr = _get_target(target)
diff --git a/src/snakeoil/currying.py b/src/snakeoil/currying.py
index f8d8d971..bc16dc0e 100644
--- a/src/snakeoil/currying.py
+++ b/src/snakeoil/currying.py
@@ -46,14 +46,19 @@ def pre_curry(func, *args, **kwargs):
"""
if not kwargs:
+
def callit(*moreargs, **morekwargs):
return func(*(args + moreargs), **morekwargs)
+
elif not args:
+
def callit(*moreargs, **morekwargs):
kw = kwargs.copy()
kw.update(morekwargs)
return func(*moreargs, **kw)
+
else:
+
def callit(*moreargs, **morekwargs):
kw = kwargs.copy()
kw.update(morekwargs)
@@ -67,14 +72,19 @@ def post_curry(func, *args, **kwargs):
"""passed in args are appended to any further args supplied"""
if not kwargs:
+
def callit(*moreargs, **morekwargs):
return func(*(moreargs + args), **morekwargs)
+
elif not args:
+
def callit(*moreargs, **morekwargs):
kw = morekwargs.copy()
kw.update(kwargs)
return func(*moreargs, **kw)
+
else:
+
def callit(*moreargs, **morekwargs):
kw = morekwargs.copy()
kw.update(kwargs)
@@ -112,18 +122,32 @@ def wrap_exception(recast_exception, *args, **kwds):
# set this here so that 2to3 will rewrite it.
try:
if not issubclass(recast_exception, Exception):
- raise ValueError("recast_exception must be an %s derivative: got %r" %
- (Exception, recast_exception))
+ raise ValueError(
+ "recast_exception must be an %s derivative: got %r"
+ % (Exception, recast_exception)
+ )
except TypeError as e:
- raise TypeError("recast_exception must be an %s derivative; got %r, failed %r",
- (Exception.__name__, recast_exception, e))
+ raise TypeError(
+ "recast_exception must be an %s derivative; got %r, failed %r",
+ (Exception.__name__, recast_exception, e),
+ )
ignores = kwds.pop("ignores", (recast_exception,))
pass_error = kwds.pop("pass_error", None)
- return wrap_exception_complex(partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores)
-
-
-def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error,
- exception, functor, args, kwds):
+ return wrap_exception_complex(
+ partial(_simple_throw, recast_exception, args, kwds, pass_error), ignores
+ )
+
+
+def _simple_throw(
+ recast_exception,
+ recast_args,
+ recast_kwds,
+ pass_error,
+ exception,
+ functor,
+ args,
+ kwds,
+):
if pass_error:
recast_kwds[pass_error] = exception
return recast_exception(*recast_args, **recast_kwds)
@@ -131,15 +155,22 @@ def _simple_throw(recast_exception, recast_args, recast_kwds, pass_error,
def wrap_exception_complex(creation_func, ignores):
try:
- if not hasattr(ignores, '__iter__') and issubclass(ignores, Exception) or ignores is Exception:
+ if (
+ not hasattr(ignores, "__iter__")
+ and issubclass(ignores, Exception)
+ or ignores is Exception
+ ):
ignores = (ignores,)
ignores = tuple(ignores)
except TypeError as e:
- raise TypeError("ignores must be either a tuple of %s, or a %s: got %r, error %r"
- % (Exception.__name__, Exception.__name__, ignores, e))
+ raise TypeError(
+ "ignores must be either a tuple of %s, or a %s: got %r, error %r"
+ % (Exception.__name__, Exception.__name__, ignores, e)
+ )
if not all(issubclass(x, Exception) for x in ignores):
- raise TypeError("ignores has a non %s derivative in it: %r" %
- (Exception.__name__, ignores))
+ raise TypeError(
+ "ignores has a non %s derivative in it: %r" % (Exception.__name__, ignores)
+ )
return partial(_inner_wrap_exception, creation_func, ignores)
@@ -153,5 +184,6 @@ def _inner_wrap_exception(exception_maker, ignores, functor):
raise
except Exception as e:
raise exception_maker(e, functor, args, kwargs) from e
+
_wrap_exception.func = functor
return pretty_docs(_wrap_exception, name=functor.__name__)
diff --git a/src/snakeoil/data_source.py b/src/snakeoil/data_source.py
index 1faa5400..3b0ccb16 100644
--- a/src/snakeoil/data_source.py
+++ b/src/snakeoil/data_source.py
@@ -33,8 +33,13 @@ we caught the exception.
"""
__all__ = (
- "base", "bz2_source", "data_source", "local_source", "text_data_source",
- "bytes_data_source", "invokable_data_source",
+ "base",
+ "bz2_source",
+ "data_source",
+ "local_source",
+ "text_data_source",
+ "bytes_data_source",
+ "invokable_data_source",
)
import errno
@@ -62,10 +67,9 @@ def _mk_writable_cls(base, name):
exceptions attribute
"""
-
base_cls = base
exceptions = (MemoryError,)
- __slots__ = ('_callback',)
+ __slots__ = ("_callback",)
def __init__(self, callback, data):
"""
@@ -85,6 +89,7 @@ def _mk_writable_cls(base, name):
self._callback(self.read())
self._callback = None
self.base_cls.close(self)
+
kls.__name__ = name
return kls
@@ -100,6 +105,7 @@ class text_ro_StringIO(stringio.text_readonly):
Specifically this adds the necessary `exceptions` attribute; see
:py:class:`snakeoil.stringio.text_readonly` for methods details.
"""
+
__slots__ = ()
exceptions = (MemoryError, TypeError)
@@ -111,6 +117,7 @@ class bytes_ro_StringIO(stringio.bytes_readonly):
Specifically this adds the necessary `exceptions` attribute; see
:py:class:`snakeoil.stringio.bytes_readonly` for methods details.
"""
+
__slots__ = ()
exceptions = (MemoryError, TypeError)
@@ -131,6 +138,7 @@ class base:
:ivar path: If None, no local path is available- else it's the ondisk path to
the data
"""
+
__slots__ = ("weakref",)
path = None
@@ -155,7 +163,8 @@ class base:
def transfer_to_path(self, path):
return self.transfer_to_data_source(
- local_source(path, mutable=True, encoding=None))
+ local_source(path, mutable=True, encoding=None)
+ )
def transfer_to_data_source(self, write_source):
read_f, m, write_f = None, None, None
@@ -208,31 +217,32 @@ class local_source(base):
raise TypeError("data source %s is immutable" % (self,))
if self.encoding:
opener = open_file
- opener = post_curry(opener, buffering=self.buffering_window,
- encoding=self.encoding)
+ opener = post_curry(
+ opener, buffering=self.buffering_window, encoding=self.encoding
+ )
else:
opener = post_curry(open_file, self.buffering_window)
if not writable:
- return opener(self.path, 'r')
+ return opener(self.path, "r")
try:
return opener(self.path, "r+")
except IOError as ie:
if ie.errno != errno.ENOENT:
raise
- return opener(self.path, 'w+')
+ return opener(self.path, "w+")
@klass.steal_docs(base)
def bytes_fileobj(self, writable=False):
if not writable:
- return open_file(self.path, 'rb', self.buffering_window)
+ return open_file(self.path, "rb", self.buffering_window)
if not self.mutable:
raise TypeError("data source %s is immutable" % (self,))
try:
- return open_file(self.path, 'rb+', self.buffering_window)
+ return open_file(self.path, "rb+", self.buffering_window)
except IOError as ie:
if ie.errno != errno.ENOENT:
raise
- return open_file(self.path, 'wb+', self.buffering_window)
+ return open_file(self.path, "wb+", self.buffering_window)
class bz2_source(base):
@@ -255,7 +265,8 @@ class bz2_source(base):
def text_fileobj(self, writable=False):
data = compression.decompress_data(
- 'bzip2', fileutils.readfile_bytes(self.path)).decode()
+ "bzip2", fileutils.readfile_bytes(self.path)
+ ).decode()
if writable:
if not self.mutable:
raise TypeError(f"data source {self} is not mutable")
@@ -263,8 +274,7 @@ class bz2_source(base):
return text_ro_StringIO(data)
def bytes_fileobj(self, writable=False):
- data = compression.decompress_data(
- 'bzip2', fileutils.readfile_bytes(self.path))
+ data = compression.decompress_data("bzip2", fileutils.readfile_bytes(self.path))
if writable:
if not self.mutable:
raise TypeError(f"data source {self} is not mutable")
@@ -275,7 +285,7 @@ class bz2_source(base):
if isinstance(data, str):
data = data.encode()
with open(self.path, "wb") as f:
- f.write(compression.compress_data('bzip2', data))
+ f.write(compression.compress_data("bzip2", data))
class data_source(base):
@@ -293,7 +303,7 @@ class data_source(base):
:ivar path: note that path is None for this class- no on disk location available.
"""
- __slots__ = ('data', 'mutable')
+ __slots__ = ("data", "mutable")
def __init__(self, data, mutable=False):
"""
@@ -305,7 +315,7 @@ class data_source(base):
self.mutable = mutable
def _convert_data(self, mode):
- if mode == 'bytes':
+ if mode == "bytes":
if isinstance(self.data, bytes):
return self.data
return self.data.encode()
@@ -318,9 +328,8 @@ class data_source(base):
if writable:
if not self.mutable:
raise TypeError(f"data source {self} is not mutable")
- return text_wr_StringIO(self._reset_data,
- self._convert_data('text'))
- return text_ro_StringIO(self._convert_data('text'))
+ return text_wr_StringIO(self._reset_data, self._convert_data("text"))
+ return text_ro_StringIO(self._convert_data("text"))
def _reset_data(self, data):
if isinstance(self.data, bytes):
@@ -335,9 +344,8 @@ class data_source(base):
if writable:
if not self.mutable:
raise TypeError(f"data source {self} is not mutable")
- return bytes_wr_StringIO(self._reset_data,
- self._convert_data('bytes'))
- return bytes_ro_StringIO(self._convert_data('bytes'))
+ return bytes_wr_StringIO(self._reset_data, self._convert_data("bytes"))
+ return bytes_ro_StringIO(self._convert_data("bytes"))
class text_data_source(data_source):
@@ -355,7 +363,7 @@ class text_data_source(data_source):
data_source.__init__(self, data, mutable=mutable)
def _convert_data(self, mode):
- if mode != 'bytes':
+ if mode != "bytes":
return self.data
return self.data.encode()
@@ -375,7 +383,7 @@ class bytes_data_source(data_source):
data_source.__init__(self, data, mutable=mutable)
def _convert_data(self, mode):
- if mode == 'bytes':
+ if mode == "bytes":
return self.data
return self.data.decode()
@@ -390,6 +398,7 @@ class invokable_data_source(data_source):
Note that this instance is explicitly readonly.
"""
+
__slots__ = ()
def __init__(self, data):
@@ -412,7 +421,9 @@ class invokable_data_source(data_source):
return self.data(False)
@classmethod
- def wrap_function(cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None):
+ def wrap_function(
+ cls, invokable, returns_text=True, returns_handle=False, encoding_hint=None
+ ):
"""
Helper function to automatically convert a function that returns text or bytes into appropriate
callable
@@ -425,10 +436,20 @@ class invokable_data_source(data_source):
:param encoding_hint: the preferred encoding to use for encoding
:return: invokable_data_source instance
"""
- return cls(partial(cls._simple_wrapper, invokable, encoding_hint, returns_text, returns_handle))
+ return cls(
+ partial(
+ cls._simple_wrapper,
+ invokable,
+ encoding_hint,
+ returns_text,
+ returns_handle,
+ )
+ )
@staticmethod
- def _simple_wrapper(invokable, encoding_hint, returns_text, returns_handle, text_wanted):
+ def _simple_wrapper(
+ invokable, encoding_hint, returns_text, returns_handle, text_wanted
+ ):
data = invokable()
if returns_text != text_wanted:
if text_wanted:
@@ -446,7 +467,7 @@ class invokable_data_source(data_source):
data = data.read()
if encoding_hint is None:
# fallback to utf8
- encoding_hint = 'utf8'
+ encoding_hint = "utf8"
data = data.encode(encoding_hint)
elif returns_handle:
return data
diff --git a/src/snakeoil/decorators.py b/src/snakeoil/decorators.py
index 550d19cb..c0e0429d 100644
--- a/src/snakeoil/decorators.py
+++ b/src/snakeoil/decorators.py
@@ -7,29 +7,36 @@ from .contexts import Namespace, SplitExec
def splitexec(func):
"""Run the decorated function in another process."""
+
@wraps(func)
def wrapper(*args, **kwargs):
with SplitExec():
return func(*args, **kwargs)
+
return wrapper
def namespace(**namespaces):
"""Run the decorated function in a specified namespace."""
+
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
with Namespace(**namespaces):
return func(*args, **kwargs)
+
return wrapper
+
return decorator
def coroutine(func):
"""Prime a coroutine for input."""
+
@wraps(func)
def prime(*args, **kwargs):
cr = func(*args, **kwargs)
next(cr)
return cr
+
return prime
diff --git a/src/snakeoil/demandimport.py b/src/snakeoil/demandimport.py
index c87e114d..0a24ccc5 100644
--- a/src/snakeoil/demandimport.py
+++ b/src/snakeoil/demandimport.py
@@ -13,14 +13,16 @@ from importlib.util import LazyLoader
_disabled = False
# modules that have issues when lazily imported
-_skip = frozenset([
- '__builtin__',
- '__future__',
- 'builtins',
- 'grp',
- 'pwd',
- 'OpenSSL.SSL', # pyopenssl
-])
+_skip = frozenset(
+ [
+ "__builtin__",
+ "__future__",
+ "builtins",
+ "grp",
+ "pwd",
+ "OpenSSL.SSL", # pyopenssl
+ ]
+)
class _LazyLoader(LazyLoader):
@@ -35,10 +37,8 @@ class _LazyLoader(LazyLoader):
# custom loaders using our extended LazyLoader
-_extensions_loader = _LazyLoader.factory(
- importlib.machinery.ExtensionFileLoader)
-_bytecode_loader = _LazyLoader.factory(
- importlib.machinery.SourcelessFileLoader)
+_extensions_loader = _LazyLoader.factory(importlib.machinery.ExtensionFileLoader)
+_bytecode_loader = _LazyLoader.factory(importlib.machinery.SourcelessFileLoader)
_source_loader = _LazyLoader.factory(importlib.machinery.SourceFileLoader)
@@ -54,7 +54,11 @@ def _filefinder(path):
def enable():
"""Enable lazy loading for all future module imports."""
- if os.environ.get('SNAKEOIL_DEMANDIMPORT', 'y').lower() not in ('n', 'no' '0', 'false'):
+ if os.environ.get("SNAKEOIL_DEMANDIMPORT", "y").lower() not in (
+ "n",
+ "no" "0",
+ "false",
+ ):
sys.path_hooks.insert(0, _filefinder)
diff --git a/src/snakeoil/demandload.py b/src/snakeoil/demandload.py
index 0c84e8cf..3800622f 100644
--- a/src/snakeoil/demandload.py
+++ b/src/snakeoil/demandload.py
@@ -49,8 +49,9 @@ from .modules import load_any
# There are some demandloaded imports below the definition of demandload.
-_allowed_chars = "".join((x.isalnum() or x in "_.") and " " or "a"
- for x in map(chr, range(256)))
+_allowed_chars = "".join(
+ (x.isalnum() or x in "_.") and " " or "a" for x in map(chr, range(256))
+)
def parse_imports(imports):
@@ -72,15 +73,16 @@ def parse_imports(imports):
:rtype: iterable of tuples of two C{str} objects.
"""
for s in imports:
- fromlist = s.split(':', 1)
+ fromlist = s.split(":", 1)
if len(fromlist) == 1:
# Not a "from" import.
- if '.' in s:
+ if "." in s:
raise ValueError(
"dotted imports are disallowed; see "
"snakeoil.demandload docstring for "
- f"details; {s!r}")
- split = s.split('@', 1)
+ f"details; {s!r}"
+ )
+ split = s.split("@", 1)
for s in split:
if not s.translate(_allowed_chars).isspace():
raise ValueError(f"bad target: {s}")
@@ -94,28 +96,33 @@ def parse_imports(imports):
base, targets = fromlist
if not base.translate(_allowed_chars).isspace():
raise ValueError(f"bad target: {base}")
- for target in targets.split(','):
- split = target.split('@', 1)
+ for target in targets.split(","):
+ split = target.split("@", 1)
for s in split:
if not s.translate(_allowed_chars).isspace():
raise ValueError(f"bad target: {s}")
- yield base + '.' + split[0], split[-1]
+ yield base + "." + split[0], split[-1]
+
def _protection_enabled_disabled():
return False
+
def _noisy_protection_disabled():
return False
+
def _protection_enabled_enabled():
val = os.environ.get("SNAKEOIL_DEMANDLOAD_PROTECTION", "n").lower()
return val in ("yes", "true", "1", "y")
+
def _noisy_protection_enabled():
val = os.environ.get("SNAKEOIL_DEMANDLOAD_WARN", "y").lower()
return val in ("yes", "true", "1", "y")
-if 'pydoc' in sys.modules or 'epydoc' in sys.modules:
+
+if "pydoc" in sys.modules or "epydoc" in sys.modules:
_protection_enabled = _protection_enabled_disabled
_noisy_protection = _noisy_protection_disabled
else:
@@ -164,15 +171,15 @@ class Placeholder:
"""
if not callable(load_func):
raise TypeError(f"load_func must be callable; got {load_func!r}")
- object.__setattr__(self, '_scope', scope)
- object.__setattr__(self, '_name', name)
- object.__setattr__(self, '_replacing_tids', [])
- object.__setattr__(self, '_load_func', load_func)
- object.__setattr__(self, '_loading_lock', threading.Lock())
+ object.__setattr__(self, "_scope", scope)
+ object.__setattr__(self, "_name", name)
+ object.__setattr__(self, "_replacing_tids", [])
+ object.__setattr__(self, "_load_func", load_func)
+ object.__setattr__(self, "_loading_lock", threading.Lock())
def _target_already_loaded(self, complain=True):
- name = object.__getattribute__(self, '_name')
- scope = object.__getattribute__(self, '_scope')
+ name = object.__getattribute__(self, "_name")
+ scope = object.__getattribute__(self, "_scope")
# in a threaded environment, it's possible for tid1 to get the
# placeholder from globals, python switches to tid2, which triggers
@@ -188,13 +195,16 @@ class Placeholder:
# it's impossible for this pathway to accidentally be triggered twice-
# meaning it is a misuse by the consuming client code.
if complain:
- tids_to_complain_about = object.__getattribute__(self, '_replacing_tids')
+ tids_to_complain_about = object.__getattribute__(self, "_replacing_tids")
if threading.current_thread().ident in tids_to_complain_about:
if _protection_enabled():
- raise ValueError(f'Placeholder for {name!r} was triggered twice')
+ raise ValueError(f"Placeholder for {name!r} was triggered twice")
elif _noisy_protection():
- logging.warning('Placeholder for %r was triggered multiple times '
- 'in file %r', name, scope.get("__file__", "unknown"))
+ logging.warning(
+ "Placeholder for %r was triggered multiple times " "in file %r",
+ name,
+ scope.get("__file__", "unknown"),
+ )
return scope[name]
def _get_target(self):
@@ -202,9 +212,9 @@ class Placeholder:
:return: the result of calling C{_load_func}.
"""
- preloaded_func = object.__getattribute__(self, '_target_already_loaded')
- with object.__getattribute__(self, '_loading_lock'):
- load_func = object.__getattribute__(self, '_load_func')
+ preloaded_func = object.__getattribute__(self, "_target_already_loaded")
+ with object.__getattribute__(self, "_loading_lock"):
+ load_func = object.__getattribute__(self, "_load_func")
if load_func is None:
# This means that there was contention; two threads made it into
# _get_target. That's fine; suppress complaints, and return the
@@ -215,18 +225,17 @@ class Placeholder:
# fix the scope, and replace this method with one that shortcircuits
# (and appropriately complains) the lookup.
result = load_func()
- scope = object.__getattribute__(self, '_scope')
- name = object.__getattribute__(self, '_name')
+ scope = object.__getattribute__(self, "_scope")
+ name = object.__getattribute__(self, "_name")
scope[name] = result
# Replace this method with the fast path/preloaded one; this
# is to ensure complaints get leveled if needed.
- object.__setattr__(self, '_get_target', preloaded_func)
- object.__setattr__(self, '_load_func', None)
-
+ object.__setattr__(self, "_get_target", preloaded_func)
+ object.__setattr__(self, "_load_func", None)
# note this step *has* to follow scope modification; else it
# will go maximum depth recursion.
- tids = object.__getattribute__(self, '_replacing_tids')
+ tids = object.__getattribute__(self, "_replacing_tids")
tids.append(threading.current_thread().ident)
return result
@@ -237,18 +246,18 @@ class Placeholder:
# Various methods proxied to our replacement.
def __str__(self):
- return self.__getattribute__('__str__')()
+ return self.__getattribute__("__str__")()
def __getattribute__(self, attr):
- result = object.__getattribute__(self, '_get_target')()
+ result = object.__getattribute__(self, "_get_target")()
return getattr(result, attr)
def __setattr__(self, attr, value):
- result = object.__getattribute__(self, '_get_target')()
+ result = object.__getattribute__(self, "_get_target")()
setattr(result, attr, value)
def __call__(self, *args, **kwargs):
- result = object.__getattribute__(self, '_get_target')()
+ result = object.__getattribute__(self, "_get_target")()
return result(*args, **kwargs)
@@ -267,7 +276,7 @@ def demandload(*imports, **kwargs):
"""
# pull the caller's global namespace if undefined
- scope = kwargs.pop('scope', sys._getframe(1).f_globals)
+ scope = kwargs.pop("scope", sys._getframe(1).f_globals)
for source, target in parse_imports(imports):
scope[target] = Placeholder.load_namespace(scope, target, source)
@@ -280,7 +289,7 @@ enabled_demandload = demandload
def disabled_demandload(*imports, **kwargs):
"""Exactly like :py:func:`demandload` but does all imports immediately."""
- scope = kwargs.pop('scope', sys._getframe(1).f_globals)
+ scope = kwargs.pop("scope", sys._getframe(1).f_globals)
for source, target in parse_imports(imports):
scope[target] = load_any(source)
@@ -292,21 +301,25 @@ def demand_compile_regexp(name, *args, **kwargs):
:param name: the name of the compiled re object in that scope.
"""
- scope = kwargs.pop('scope', sys._getframe(1).f_globals)
+ scope = kwargs.pop("scope", sys._getframe(1).f_globals)
scope[name] = Placeholder.load_regex(scope, name, *args, **kwargs)
def disabled_demand_compile_regexp(name, *args, **kwargs):
"""Exactly like :py:func:`demand_compile_regexp` but does all imports immediately."""
- scope = kwargs.pop('scope', sys._getframe(1).f_globals)
+ scope = kwargs.pop("scope", sys._getframe(1).f_globals)
scope[name] = re.compile(*args, **kwargs)
-if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", 'n').lower() in ('y', 'yes' '1', 'true'):
+if os.environ.get("SNAKEOIL_DEMANDLOAD_DISABLED", "n").lower() in (
+ "y",
+ "yes" "1",
+ "true",
+):
demandload = disabled_demandload
demand_compile_regexp = disabled_demand_compile_regexp
demandload(
- 'logging',
- 're',
+ "logging",
+ "re",
)
diff --git a/src/snakeoil/dependant_methods.py b/src/snakeoil/dependant_methods.py
index 551af1e6..031ad630 100644
--- a/src/snakeoil/dependant_methods.py
+++ b/src/snakeoil/dependant_methods.py
@@ -63,7 +63,7 @@ def _ensure_deps(cls_id, name, func, self, *a, **kw):
s = _yield_deps(self, self.stage_depends, name)
r = True
- if not hasattr(self, '_stage_state'):
+ if not hasattr(self, "_stage_state"):
self._stage_state = set()
for dep in s:
if dep not in self._stage_state:
@@ -108,8 +108,8 @@ def __wrap_stage_dependencies__(cls):
f = getattr(cls, x)
except AttributeError:
raise TypeError(
- "class %r stage_depends specifies %r, which doesn't exist" %
- (cls, x))
+ "class %r stage_depends specifies %r, which doesn't exist" % (cls, x)
+ )
f2 = pre_curry(_ensure_deps, cls_id, x, f)
f2.sd_raw_func = f
setattr(cls, x, f2)
@@ -122,9 +122,9 @@ def __unwrap_stage_dependencies__(cls):
f = getattr(cls, x)
except AttributeError:
raise TypeError(
- "class %r stage_depends specifies %r, which doesn't exist" %
- (cls, x))
- setattr(cls, x, getattr(f, 'sd_raw_func', f))
+ "class %r stage_depends specifies %r, which doesn't exist" % (cls, x)
+ )
+ setattr(cls, x, getattr(f, "sd_raw_func", f))
def __set_stage_state__(self, state):
@@ -165,17 +165,17 @@ class ForcedDepends(type):
def __new__(cls, name, bases, d):
obj = super(ForcedDepends, cls).__new__(cls, name, bases, d)
- if not hasattr(obj, 'stage_depends'):
+ if not hasattr(obj, "stage_depends"):
obj.stage_depends = {}
for x in ("wrap", "unwrap"):
- s = '__%s_stage_dependencies__' % x
+ s = "__%s_stage_dependencies__" % x
if not hasattr(obj, s):
setattr(obj, s, classmethod(globals()[s]))
obj.__unwrap_stage_dependencies__()
obj.__wrap_stage_dependencies__()
- if not hasattr(obj, '__force_stage_state__'):
+ if not hasattr(obj, "__force_stage_state__"):
obj.__set_stage_state__ = __set_stage_state__
- if not hasattr(obj, '__stage_step_callback__'):
+ if not hasattr(obj, "__stage_step_callback__"):
obj.__stage_step_callback__ = __stage_step_callback__
return obj
diff --git a/src/snakeoil/errors.py b/src/snakeoil/errors.py
index 9b7b5411..ab373489 100644
--- a/src/snakeoil/errors.py
+++ b/src/snakeoil/errors.py
@@ -12,20 +12,20 @@ def walk_exception_chain(exc, ignore_first=False, reverse=False):
def _inner_walk_exception_chain(exc, ignore_first):
if not ignore_first:
yield exc
- exc = getattr(exc, '__cause__', None)
+ exc = getattr(exc, "__cause__", None)
while exc is not None:
yield exc
- exc = getattr(exc, '__cause__', None)
+ exc = getattr(exc, "__cause__", None)
def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None):
# force default output for exceptions
- if getattr(handle, 'reset', False):
+ if getattr(handle, "reset", False):
handle.write(handle.reset)
- prefix = ''
+ prefix = ""
if msg:
- prefix = ' '
+ prefix = " "
handle.write(msg.rstrip("\n") + ":\n")
if tb:
handle.write("Traceback follows:\n")
@@ -34,8 +34,8 @@ def dump_error(raw_exc, msg=None, handle=sys.stderr, tb=None):
if raw_exc is not None:
for exc in walk_exception_chain(raw_exc):
exc_strings.extend(
- prefix + x.strip()
- for x in (x for x in str(exc).split("\n") if x))
+ prefix + x.strip() for x in (x for x in str(exc).split("\n") if x)
+ )
if exc_strings:
if msg and tb:
handle.write(f"\n{raw_exc.__class__.__name__}:\n")
diff --git a/src/snakeoil/fileutils.py b/src/snakeoil/fileutils.py
index ed3a0342..b0aa7247 100644
--- a/src/snakeoil/fileutils.py
+++ b/src/snakeoil/fileutils.py
@@ -23,23 +23,28 @@ def touch(fname, mode=0o644, **kwargs):
See os.utime for other supported arguments.
"""
flags = os.O_CREAT | os.O_APPEND
- dir_fd = kwargs.get('dir_fd', None)
+ dir_fd = kwargs.get("dir_fd", None)
os_open = partial(os.open, dir_fd=dir_fd)
with os.fdopen(os_open(fname, flags, mode)) as f:
os.utime(
f.fileno() if os.utime in os.supports_fd else fname,
- dir_fd=None if os.supports_fd else dir_fd, **kwargs)
+ dir_fd=None if os.supports_fd else dir_fd,
+ **kwargs
+ )
+
def mmap_or_open_for_read(path):
size = os.stat(path).st_size
if size == 0:
- return (None, data_source.bytes_ro_StringIO(b''))
+ return (None, data_source.bytes_ro_StringIO(b""))
fd = None
try:
fd = os.open(path, os.O_RDONLY)
- return (_fileutils.mmap_and_close(
- fd, size, mmap.MAP_SHARED, mmap.PROT_READ), None)
+ return (
+ _fileutils.mmap_and_close(fd, size, mmap.MAP_SHARED, mmap.PROT_READ),
+ None,
+ )
except IGNORED_EXCEPTIONS:
raise
except:
@@ -85,7 +90,8 @@ class AtomicWriteFile_mixin:
fp = os.path.realpath(fp)
self._original_fp = fp
self._temp_fp = os.path.join(
- os.path.dirname(fp), ".update." + os.path.basename(fp))
+ os.path.dirname(fp), ".update." + os.path.basename(fp)
+ )
old_umask = None
if perms:
# give it just write perms
@@ -140,7 +146,7 @@ class AtomicWriteFile(AtomicWriteFile_mixin):
self.raw = open(self._temp_fp, mode=self._computed_mode)
def _real_close(self):
- if hasattr(self, 'raw'):
+ if hasattr(self, "raw"):
return self.raw.close()
return None
@@ -149,24 +155,23 @@ class AtomicWriteFile(AtomicWriteFile_mixin):
def _mk_pretty_derived_func(func, name_base, name, *args, **kwds):
if name:
- name = '_' + name
- return pretty_docs(partial(func, *args, **kwds),
- name='%s%s' % (name_base, name))
+ name = "_" + name
+ return pretty_docs(partial(func, *args, **kwds), name="%s%s" % (name_base, name))
-_mk_readfile = partial(
- _mk_pretty_derived_func, _fileutils.native_readfile, 'readfile')
+_mk_readfile = partial(_mk_pretty_derived_func, _fileutils.native_readfile, "readfile")
-readfile_ascii = _mk_readfile('ascii', 'rt')
-readfile_bytes = _mk_readfile('bytes', 'rb')
-readfile_utf8 = _mk_readfile('utf8', 'r', encoding='utf8')
+readfile_ascii = _mk_readfile("ascii", "rt")
+readfile_bytes = _mk_readfile("bytes", "rb")
+readfile_utf8 = _mk_readfile("utf8", "r", encoding="utf8")
readfile = readfile_utf8
_mk_readlines = partial(
- _mk_pretty_derived_func, _fileutils.native_readlines, 'readlines')
+ _mk_pretty_derived_func, _fileutils.native_readlines, "readlines"
+)
-readlines_ascii = _mk_readlines('ascii', 'r', encoding='ascii')
-readlines_bytes = _mk_readlines('bytes', 'rb')
-readlines_utf8 = _mk_readlines('utf8', 'r', encoding='utf8')
+readlines_ascii = _mk_readlines("ascii", "r", encoding="ascii")
+readlines_bytes = _mk_readlines("bytes", "rb")
+readlines_utf8 = _mk_readlines("utf8", "r", encoding="utf8")
readlines = readlines_utf8
diff --git a/src/snakeoil/formatters.py b/src/snakeoil/formatters.py
index 397667fa..3faf661a 100644
--- a/src/snakeoil/formatters.py
+++ b/src/snakeoil/formatters.py
@@ -10,7 +10,9 @@ from .klass import GetAttrProxy, steal_docs
from .mappings import defaultdictkey
__all__ = (
- "Formatter", "PlainTextFormatter", "get_formatter",
+ "Formatter",
+ "PlainTextFormatter",
+ "get_formatter",
"decorate_forced_wrapping",
)
@@ -98,13 +100,11 @@ class Formatter:
def error(self, message):
"""Format a string as an error message."""
- self.write(message, prefixes=(
- self.fg('red'), self.bold, '!!! ', self.reset))
+ self.write(message, prefixes=(self.fg("red"), self.bold, "!!! ", self.reset))
def warn(self, message):
"""Format a string as a warning message."""
- self.write(message, prefixes=(
- self.fg('yellow'), self.bold, '*** ', self.reset))
+ self.write(message, prefixes=(self.fg("yellow"), self.bold, "*** ", self.reset))
def title(self, string):
"""Set the title to string"""
@@ -123,7 +123,7 @@ class PlainTextFormatter(Formatter):
every write.
"""
- bold = underline = reset = ''
+ bold = underline = reset = ""
def __init__(self, stream, width=79, encoding=None):
"""Initialize.
@@ -144,12 +144,12 @@ class PlainTextFormatter(Formatter):
else:
self.stream = stream
if encoding is None:
- encoding = getattr(self.stream, 'encoding', None)
+ encoding = getattr(self.stream, "encoding", None)
if encoding is None:
try:
encoding = locale.getpreferredencoding()
except locale.Error:
- encoding = 'ascii'
+ encoding = "ascii"
self.encoding = encoding
self.width = width
self._pos = 0
@@ -162,7 +162,7 @@ class PlainTextFormatter(Formatter):
return True
def _force_encoding(self, val):
- return val.encode(self.encoding, 'replace')
+ return val.encode(self.encoding, "replace")
def _write_prefix(self, wrap):
if self._in_first_line:
@@ -190,34 +190,32 @@ class PlainTextFormatter(Formatter):
@steal_docs(Formatter)
def write(self, *args, **kwargs):
- wrap = kwargs.get('wrap', self.wrap)
- autoline = kwargs.get('autoline', self.autoline)
- prefixes = kwargs.get('prefixes')
- first_prefixes = kwargs.get('first_prefixes')
- later_prefixes = kwargs.get('later_prefixes')
+ wrap = kwargs.get("wrap", self.wrap)
+ autoline = kwargs.get("autoline", self.autoline)
+ prefixes = kwargs.get("prefixes")
+ first_prefixes = kwargs.get("first_prefixes")
+ later_prefixes = kwargs.get("later_prefixes")
if prefixes is not None:
if first_prefixes is not None or later_prefixes is not None:
raise TypeError(
- 'do not pass first_prefixes or later_prefixes '
- 'if prefixes is passed')
+ "do not pass first_prefixes or later_prefixes "
+ "if prefixes is passed"
+ )
first_prefixes = later_prefixes = prefixes
- prefix = kwargs.get('prefix')
- first_prefix = kwargs.get('first_prefix')
- later_prefix = kwargs.get('later_prefix')
+ prefix = kwargs.get("prefix")
+ first_prefix = kwargs.get("first_prefix")
+ later_prefix = kwargs.get("later_prefix")
if prefix is not None:
if first_prefix is not None or later_prefix is not None:
- raise TypeError(
- 'do not pass first_prefix or later_prefix with prefix')
+ raise TypeError("do not pass first_prefix or later_prefix with prefix")
first_prefix = later_prefix = prefix
if first_prefix is not None:
if first_prefixes is not None:
- raise TypeError(
- 'do not pass both first_prefix and first_prefixes')
+ raise TypeError("do not pass both first_prefix and first_prefixes")
first_prefixes = (first_prefix,)
if later_prefix is not None:
if later_prefixes is not None:
- raise TypeError(
- 'do not pass both later_prefix and later_prefixes')
+ raise TypeError("do not pass both later_prefix and later_prefixes")
later_prefixes = (later_prefix,)
if first_prefixes is not None:
self.first_prefix.extend(first_prefixes)
@@ -242,7 +240,7 @@ class PlainTextFormatter(Formatter):
while wrap and self._pos + len(arg) > self.width:
# We have to split.
maxlen = self.width - self._pos
- space = arg.rfind(' ', 0, maxlen)
+ space = arg.rfind(" ", 0, maxlen)
if space == -1:
# No space to split on.
@@ -254,7 +252,7 @@ class PlainTextFormatter(Formatter):
# written something we can also go to the next
# line.
if self._in_first_line or self._wrote_something:
- bit = ''
+ bit = ""
else:
# Forcibly split this as far to the right as
# possible.
@@ -263,11 +261,11 @@ class PlainTextFormatter(Formatter):
else:
bit = arg[:space]
# Omit the space we split on.
- arg = arg[space + 1:]
+ arg = arg[space + 1 :]
if conversion_needed:
bit = self._force_encoding(bit)
self.stream.write(bit)
- self.stream.write(self._force_encoding('\n'))
+ self.stream.write(self._force_encoding("\n"))
self._pos = 0
self._in_first_line = False
self._wrote_something = False
@@ -280,7 +278,7 @@ class PlainTextFormatter(Formatter):
arg = self._force_encoding(arg)
self.stream.write(arg)
if autoline:
- self.stream.write(self._force_encoding('\n'))
+ self.stream.write(self._force_encoding("\n"))
self._wrote_something = False
self._pos = 0
self._in_first_line = True
@@ -290,32 +288,28 @@ class PlainTextFormatter(Formatter):
raise
finally:
if first_prefixes is not None:
- self.first_prefix = self.first_prefix[:-len(first_prefixes)]
+ self.first_prefix = self.first_prefix[: -len(first_prefixes)]
if later_prefixes is not None:
- self.later_prefix = self.later_prefix[:-len(later_prefixes)]
+ self.later_prefix = self.later_prefix[: -len(later_prefixes)]
def fg(self, color=None):
"""change fg color
Compatibility method- no coloring escapes are returned from it.
"""
- return ''
+ return ""
def bg(self, color=None):
"""change bg color
Compatibility method- no coloring escapes are returned from it.
"""
- return ''
+ return ""
def flush(self):
self.stream.flush()
-
-
-
-
class TerminfoDisabled(Exception):
"""Raised if Terminfo is disabled."""
@@ -331,7 +325,7 @@ class TerminfoUnsupported(Exception):
self.term = term
def __str__(self):
- return f'unsupported terminal type: {self.term!r}'
+ return f"unsupported terminal type: {self.term!r}"
# This is necessary because the curses module is optional (and we
@@ -341,6 +335,7 @@ try:
except ImportError:
TerminfoColor = None
else:
+
class TerminfoColor:
"""Class encapsulating a specific terminfo entry for a color.
@@ -351,8 +346,8 @@ else:
__slots__ = ("mode", "color", "__weakref__")
def __init__(self, mode, color):
- object.__setattr__(self, 'mode', mode)
- object.__setattr__(self, 'color', color)
+ object.__setattr__(self, "mode", mode)
+ object.__setattr__(self, "color", color)
def __call__(self, formatter):
if self.color is None:
@@ -374,7 +369,7 @@ else:
if template:
res = curses.tparm(template, color)
else:
- res = b''
+ res = b""
formatter._current_colors[self.mode] = res
formatter.stream.write(res)
@@ -393,7 +388,7 @@ else:
def __init__(self, value):
if value is None:
raise _BogusTerminfo()
- object.__setattr__(self, 'value', value)
+ object.__setattr__(self, "value", value)
def __setattr__(self, key, value):
raise AttributeError(f"{self.__class__.__name__} instances are immutable")
@@ -441,33 +436,32 @@ else:
super().__init__(stream, encoding=encoding)
fd = stream.fileno()
if term is None:
- if term := os.environ.get('TERM'):
+ if term := os.environ.get("TERM"):
try:
curses.setupterm(fd=fd, term=term)
except curses.error:
pass
else:
- raise TerminfoDisabled('no terminfo entries')
+ raise TerminfoDisabled("no terminfo entries")
else:
# TODO maybe do something more useful than raising curses.error
# if term is not in the terminfo db here?
curses.setupterm(fd=fd, term=term)
self._term = term
- self.width = curses.tigetnum('cols')
+ self.width = curses.tigetnum("cols")
try:
- self.reset = TerminfoReset(curses.tigetstr('sgr0'))
- self.bold = TerminfoMode(curses.tigetstr('bold'))
- self.underline = TerminfoMode(curses.tigetstr('smul'))
- self._color_reset = curses.tigetstr('op')
- self._set_color = (
- curses.tigetstr('setaf'),
- curses.tigetstr('setab'))
+ self.reset = TerminfoReset(curses.tigetstr("sgr0"))
+ self.bold = TerminfoMode(curses.tigetstr("bold"))
+ self.underline = TerminfoMode(curses.tigetstr("smul"))
+ self._color_reset = curses.tigetstr("op")
+ self._set_color = (curses.tigetstr("setaf"), curses.tigetstr("setab"))
except (_BogusTerminfo, curses.error) as e:
raise TerminfoUnsupported(self._term) from e
if not all(self._set_color):
raise TerminfoDisabled(
- 'setting background/foreground colors is not supported')
+ "setting background/foreground colors is not supported"
+ )
curses.tparm(self._set_color[0], curses.COLOR_WHITE)
@@ -507,16 +501,14 @@ else:
# not set the hs flag. So just check for the ability to
# jump to and out of the status line, without checking if
# the status line we're using exists.
- tsl = curses.tigetstr('tsl')
- fsl = curses.tigetstr('fsl')
+ tsl = curses.tigetstr("tsl")
+ fsl = curses.tigetstr("fsl")
if tsl and fsl:
- self.stream.write(
- tsl + string.encode(self.encoding, 'replace') + fsl)
+ self.stream.write(tsl + string.encode(self.encoding, "replace") + fsl)
self.stream.flush()
class ObserverFormatter:
-
def __init__(self, real_formatter):
self._formatter = real_formatter
@@ -542,7 +534,7 @@ def get_formatter(stream, force_color=False):
# needs an fd to pass to curses, not just a filelike talking to a tty.
if os.isatty(fd) or force_color:
try:
- term = 'ansi' if force_color else None
+ term = "ansi" if force_color else None
return TerminfoFormatter(stream, term=term)
except (curses.error, TerminfoDisabled, TerminfoUnsupported):
# This happens if TERM is unset and possibly in more cases.
@@ -553,6 +545,7 @@ def get_formatter(stream, force_color=False):
def decorate_forced_wrapping(setting=True):
"""Decorator to force a specific line wrapping state for the duration of invocation."""
+
def wrapped_func(func):
def f(out, *args, **kwds):
oldwrap = out.wrap
@@ -561,5 +554,7 @@ def decorate_forced_wrapping(setting=True):
return func(out, *args, **kwds)
finally:
out.wrap = oldwrap
+
return f
+
return wrapped_func
diff --git a/src/snakeoil/iterables.py b/src/snakeoil/iterables.py
index 787af672..7176a2be 100644
--- a/src/snakeoil/iterables.py
+++ b/src/snakeoil/iterables.py
@@ -19,8 +19,7 @@ def partition(iterable, predicate=bool):
filter and the second the matched items.
"""
a, b = itertools.tee((predicate(x), x) for x in iterable)
- return ((x for pred, x in a if not pred),
- (x for pred, x in b if pred))
+ return ((x for pred, x in a if not pred), (x for pred, x in b if pred))
class expandable_chain:
@@ -107,6 +106,7 @@ class caching_iter:
3
"""
+
__slots__ = ("iterable", "__weakref__", "cached_list", "sorter")
def __init__(self, iterable, sorter=None):
@@ -139,7 +139,7 @@ class caching_iter:
if self.iterable is not None:
i = itertools.islice(self.iterable, 0, index - (existing_len - 1))
self.cached_list.extend(i)
- if len(self.cached_list) -1 != index:
+ if len(self.cached_list) - 1 != index:
# consumed, baby.
self.iterable = None
self.cached_list = tuple(self.cached_list)
@@ -209,8 +209,7 @@ class caching_iter:
return len(self.cached_list)
def __iter__(self):
- if (self.sorter is not None and
- self.iterable is not None):
+ if self.sorter is not None and self.iterable is not None:
if self.cached_list:
self.cached_list.extend(self.iterable)
self.cached_list = tuple(self.sorter(self.cached_list))
@@ -237,8 +236,7 @@ class caching_iter:
return hash(self.cached_list)
def __str__(self):
- return "iterable(%s), cached: %s" % (
- self.iterable, str(self.cached_list))
+ return "iterable(%s), cached: %s" % (self.iterable, str(self.cached_list))
def iter_sort(sorter, *iterables):
diff --git a/src/snakeoil/klass.py b/src/snakeoil/klass.py
index 0e592588..23d6d3c5 100644
--- a/src/snakeoil/klass.py
+++ b/src/snakeoil/klass.py
@@ -7,12 +7,27 @@ involved in writing classes.
"""
__all__ = (
- "generic_equality", "reflective_hash", "inject_richcmp_methods_from_cmp",
- "static_attrgetter", "instance_attrgetter", "jit_attr", "jit_attr_none",
- "jit_attr_named", "jit_attr_ext_method", "alias_attr", "cached_hash",
- "cached_property", "cached_property_named",
- "steal_docs", "immutable_instance", "inject_immutable_instance",
- "alias_method", "aliased", "alias", "patch", "SlotsPicklingMixin",
+ "generic_equality",
+ "reflective_hash",
+ "inject_richcmp_methods_from_cmp",
+ "static_attrgetter",
+ "instance_attrgetter",
+ "jit_attr",
+ "jit_attr_none",
+ "jit_attr_named",
+ "jit_attr_ext_method",
+ "alias_attr",
+ "cached_hash",
+ "cached_property",
+ "cached_property_named",
+ "steal_docs",
+ "immutable_instance",
+ "inject_immutable_instance",
+ "alias_method",
+ "aliased",
+ "alias",
+ "patch",
+ "SlotsPicklingMixin",
)
import inspect
@@ -32,6 +47,7 @@ sentinel = object()
def GetAttrProxy(target):
def reflected_getattr(self, attr):
return getattr(object.__getattribute__(self, target), attr)
+
return reflected_getattr
@@ -43,6 +59,7 @@ def DirProxy(target):
except AttributeError:
attrs.extend(obj.__slots__)
return sorted(set(attrs))
+
return combined_dir
@@ -69,6 +86,8 @@ def get(self, key, default=None):
_attrlist_getter = attrgetter("__attr_comparison__")
+
+
def generic_attr_eq(inst1, inst2):
"""
compare inst1 to inst2, returning True if equal, False if not.
@@ -78,8 +97,7 @@ def generic_attr_eq(inst1, inst2):
if inst1 is inst2:
return True
for attr in _attrlist_getter(inst1):
- if getattr(inst1, attr, sentinel) != \
- getattr(inst2, attr, sentinel):
+ if getattr(inst1, attr, sentinel) != getattr(inst2, attr, sentinel):
return False
return True
@@ -105,28 +123,36 @@ def reflective_hash(attr):
:param attr: attribute name to pull the hash from on the instance
:return: hash value for instance this func is used in.
"""
+
def __hash__(self):
return getattr(self, attr)
+
return __hash__
+
def _internal_jit_attr(
- func, attr_name, singleton=None,
- use_cls_setattr=False, use_singleton=True, doc=None):
+ func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True, doc=None
+):
"""Object implementing the descriptor protocol for use in Just In Time access to attributes.
Consumers should likely be using the :py:func:`jit_func` line of helper functions
instead of directly consuming this.
"""
- doc = getattr(func, '__doc__', None) if doc is None else doc
+ doc = getattr(func, "__doc__", None) if doc is None else doc
class _internal_jit_attr(_raw_internal_jit_attr):
__doc__ = doc
__slots__ = ()
+
kls = _internal_jit_attr
return kls(
- func, attr_name, singleton=singleton, use_cls_setattr=use_cls_setattr,
- use_singleton=use_singleton)
+ func,
+ attr_name,
+ singleton=singleton,
+ use_cls_setattr=use_cls_setattr,
+ use_singleton=use_singleton,
+ )
class _raw_internal_jit_attr:
@@ -134,8 +160,9 @@ class _raw_internal_jit_attr:
__slots__ = ("storage_attr", "function", "_setter", "singleton", "use_singleton")
- def __init__(self, func, attr_name, singleton=None,
- use_cls_setattr=False, use_singleton=True):
+ def __init__(
+ self, func, attr_name, singleton=None, use_cls_setattr=False, use_singleton=True
+ ):
"""
:param func: function to invoke upon first request for this content
:param attr_name: attribute name to store the generated value in
@@ -178,8 +205,9 @@ class _raw_internal_jit_attr:
return obj
-def generic_equality(name, bases, scope, real_type=type,
- eq=generic_attr_eq, ne=generic_attr_ne):
+def generic_equality(
+ name, bases, scope, real_type=type, eq=generic_attr_eq, ne=generic_attr_ne
+):
"""
metaclass generating __eq__/__ne__ methods from an attribute list
@@ -208,7 +236,9 @@ def generic_equality(name, bases, scope, real_type=type,
attrlist = scope[attrlist]
for x in attrlist:
if not isinstance(x, str):
- raise TypeError(f"all members of attrlist must be strings- got {type(x)!r} {x!r}")
+ raise TypeError(
+ f"all members of attrlist must be strings- got {type(x)!r} {x!r}"
+ )
scope["__attr_comparison__"] = tuple(attrlist)
scope.setdefault("__eq__", eq)
@@ -285,9 +315,14 @@ def inject_richcmp_methods_from_cmp(scope):
:param scope: the modifiable scope of a class namespace to work on
"""
- for key, func in (("__lt__", generic_lt), ("__le__", generic_le),
- ("__eq__", generic_eq), ("__ne__", generic_ne),
- ("__ge__", generic_ge), ("__gt__", generic_gt)):
+ for key, func in (
+ ("__lt__", generic_lt),
+ ("__le__", generic_le),
+ ("__eq__", generic_eq),
+ ("__ne__", generic_ne),
+ ("__ge__", generic_ge),
+ ("__gt__", generic_gt),
+ ):
scope.setdefault(key, func)
@@ -329,7 +364,8 @@ class chained_getter(metaclass=partial(generic_equality, real_type=WeakInstMeta)
>>> print(o.recursive == foo.seq.__hash__)
True
"""
- __slots__ = ('namespace', 'getter')
+
+ __slots__ = ("namespace", "getter")
__fifo_cache__ = deque()
__inst_caching__ = True
__attr_comparison__ = ("namespace",)
@@ -361,16 +397,20 @@ instance_attrgetter = chained_getter
# this annoyingly means our docs have to be recommitted every change,
# even if no real code changed (since the id() continually moves)...
class _singleton_kls:
-
def __str__(self):
return "uncached singleton instance"
_uncached_singleton = _singleton_kls
-T = typing.TypeVar('T')
+T = typing.TypeVar("T")
+
-def jit_attr(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, uncached_val: typing.Any=_uncached_singleton) -> T:
+def jit_attr(
+ func: typing.Callable[[typing.Any], T],
+ kls=_internal_jit_attr,
+ uncached_val: typing.Any = _uncached_singleton,
+) -> T:
"""
decorator to JIT generate, and cache the wrapped functions result in
'_' + func.__name__ on the instance.
@@ -399,8 +439,13 @@ def jit_attr_none(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr
return jit_attr(func, kls=kls, uncached_val=None)
-def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_jit_attr,
- uncached_val: typing.Any=_uncached_singleton, doc=None):
+def jit_attr_named(
+ stored_attr_name: str,
+ use_cls_setattr=False,
+ kls=_internal_jit_attr,
+ uncached_val: typing.Any = _uncached_singleton,
+ doc=None,
+):
"""
Version of :py:func:`jit_attr` decorator that allows for explicit control over the
attribute name used to store the cache value.
@@ -410,9 +455,14 @@ def jit_attr_named(stored_attr_name: str, use_cls_setattr=False, kls=_internal_j
return post_curry(kls, stored_attr_name, uncached_val, use_cls_setattr, doc=doc)
-def jit_attr_ext_method(func_name: str, stored_attr_name: str,
- use_cls_setattr=False, kls=_internal_jit_attr,
- uncached_val: typing.Any=_uncached_singleton, doc=None):
+def jit_attr_ext_method(
+ func_name: str,
+ stored_attr_name: str,
+ use_cls_setattr=False,
+ kls=_internal_jit_attr,
+ uncached_val: typing.Any = _uncached_singleton,
+ doc=None,
+):
"""
Decorator handing maximal control of attribute JIT'ing to the invoker.
@@ -421,11 +471,20 @@ def jit_attr_ext_method(func_name: str, stored_attr_name: str,
Generally speaking, you only need this when you are doing something rather *special*.
"""
- return kls(alias_method(func_name), stored_attr_name,
- uncached_val, use_cls_setattr, doc=doc)
+ return kls(
+ alias_method(func_name),
+ stored_attr_name,
+ uncached_val,
+ use_cls_setattr,
+ doc=doc,
+ )
-def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_attr, use_cls_setattr=False) -> T:
+def cached_property(
+ func: typing.Callable[[typing.Any], T],
+ kls=_internal_jit_attr,
+ use_cls_setattr=False,
+) -> T:
"""
like `property`, just with caching
@@ -454,8 +513,9 @@ def cached_property(func: typing.Callable[[typing.Any], T], kls=_internal_jit_at
>>> print(obj.attr)
1
"""
- return kls(func, func.__name__, None, use_singleton=False,
- use_cls_setattr=use_cls_setattr)
+ return kls(
+ func, func.__name__, None, use_singleton=False, use_cls_setattr=use_cls_setattr
+ )
def cached_property_named(name: str, kls=_internal_jit_attr, use_cls_setattr=False):
@@ -538,11 +598,13 @@ def cached_hash(func):
>>> assert hash(f) == 12345 # note we still get the same value
>>> assert f.hash_invocations == 1 # and that the function was invoked only once.
"""
+
def __hash__(self):
- val = getattr(self, '_hash', None)
+ val = getattr(self, "_hash", None)
if val is None:
- object.__setattr__(self, '_hash', val := func(self))
+ object.__setattr__(self, "_hash", val := func(self))
return val
+
return __hash__
@@ -574,6 +636,7 @@ def steal_docs(target, ignore_missing=False, name=None):
>>> f = foo([1,2,3])
>>> assert f.extend.__doc__ == list.extend.__doc__
"""
+
def inner(functor):
if inspect.isclass(target):
if name is not None:
@@ -590,6 +653,7 @@ def steal_docs(target, ignore_missing=False, name=None):
obj = target
functor.__doc__ = obj.__doc__
return functor
+
return inner
@@ -611,7 +675,7 @@ def patch(target, external_decorator=None):
"""
def _import_module(target):
- components = target.split('.')
+ components = target.split(".")
import_path = components.pop(0)
module = import_module(import_path)
for comp in components:
@@ -624,7 +688,7 @@ def patch(target, external_decorator=None):
def _get_target(target):
try:
- module, attr = target.rsplit('.', 1)
+ module, attr = target.rsplit(".", 1)
except (TypeError, ValueError):
raise TypeError(f"invalid target: {target!r}")
module = _import_module(module)
@@ -632,7 +696,7 @@ def patch(target, external_decorator=None):
def decorator(func):
# use the original function wrapper
- func = getattr(func, '_func', func)
+ func = getattr(func, "_func", func)
module, attr = _get_target(target)
orig_func = getattr(module, attr)
@@ -749,6 +813,7 @@ class alias:
>>> speak = Speak()
>>> assert speak.shout('foo') == speak.yell('foo') == speak.scream('foo')
"""
+
def __init__(self, *aliases):
self.aliases = set(aliases)
@@ -762,11 +827,14 @@ def aliased(cls):
orig_methods = cls.__dict__.copy()
seen_aliases = set()
for _name, method in orig_methods.items():
- if hasattr(method, '_aliases'):
- collisions = method._aliases.intersection(orig_methods.keys() | seen_aliases)
+ if hasattr(method, "_aliases"):
+ collisions = method._aliases.intersection(
+ orig_methods.keys() | seen_aliases
+ )
if collisions:
raise ValueError(
- f"aliases collide with existing attributes: {', '.join(collisions)}")
+ f"aliases collide with existing attributes: {', '.join(collisions)}"
+ )
seen_aliases |= method._aliases
for alias in method._aliases:
setattr(cls, alias, method)
@@ -780,9 +848,13 @@ class SlotsPicklingMixin:
def __getstate__(self):
all_slots = itertools.chain.from_iterable(
- getattr(t, '__slots__', ()) for t in type(self).__mro__)
- state = {attr: getattr(self, attr) for attr in all_slots
- if hasattr(self, attr) and attr != '__weakref__'}
+ getattr(t, "__slots__", ()) for t in type(self).__mro__
+ )
+ state = {
+ attr: getattr(self, attr)
+ for attr in all_slots
+ if hasattr(self, attr) and attr != "__weakref__"
+ }
return state
def __setstate__(self, state):
diff --git a/src/snakeoil/mappings.py b/src/snakeoil/mappings.py
index c3498978..d4ac221d 100644
--- a/src/snakeoil/mappings.py
+++ b/src/snakeoil/mappings.py
@@ -3,10 +3,17 @@ Miscellaneous mapping related classes and functionality
"""
__all__ = (
- "DictMixin", "LazyValDict", "LazyFullValLoadDict",
- "ProtectedDict", "ImmutableDict", "IndeterminantDict",
- "defaultdictkey", "AttrAccessible", "StackedDict",
- "make_SlottedDict_kls", "ProxiedAttrs",
+ "DictMixin",
+ "LazyValDict",
+ "LazyFullValLoadDict",
+ "ProtectedDict",
+ "ImmutableDict",
+ "IndeterminantDict",
+ "defaultdictkey",
+ "AttrAccessible",
+ "StackedDict",
+ "make_SlottedDict_kls",
+ "ProxiedAttrs",
)
import operator
@@ -168,6 +175,7 @@ class LazyValDict(DictMixin):
given a function to get keys, and to look up the val for those keys, it'll
lazily load key definitions and values as requested
"""
+
__slots__ = ("_keys", "_keys_func", "_vals", "_val_func")
__externally_mutable__ = False
@@ -184,8 +192,7 @@ class LazyValDict(DictMixin):
self._keys_func = None
else:
if not callable(get_keys_func):
- raise TypeError(
- "get_keys_func isn't iterable or callable")
+ raise TypeError("get_keys_func isn't iterable or callable")
self._keys_func = get_keys_func
self._val_func = get_val_func
self._vals = {}
@@ -234,6 +241,7 @@ class LazyFullValLoadDict(LazyValDict):
The val function must still return values one by one per key.
"""
+
__slots__ = ()
def __getitem__(self, key):
@@ -297,8 +305,7 @@ class ProtectedDict(DictMixin):
yield k
def __contains__(self, key):
- return key in self.new or (key not in self.blacklist and
- key in self.orig)
+ return key in self.new or (key not in self.blacklist and key in self.orig)
class ImmutableDict(Mapping):
@@ -320,14 +327,14 @@ class ImmutableDict(Mapping):
try:
mapping = {k: v for k, v in data}
except TypeError as e:
- raise TypeError(f'unsupported data format: {e}')
- object.__setattr__(self, '_dict', mapping)
+ raise TypeError(f"unsupported data format: {e}")
+ object.__setattr__(self, "_dict", mapping)
def __getitem__(self, key):
# hack to avoid recursion exceptions for subclasses that use
# inject_getitem_as_getattr()
- if key == '_dict':
- return object.__getattribute__(self, '_dict')
+ if key == "_dict":
+ return object.__getattribute__(self, "_dict")
return self._dict[key]
def __iter__(self):
@@ -356,7 +363,7 @@ class OrderedFrozenSet(Set):
try:
self._dict = ImmutableDict({x: None for x in iterable})
except TypeError as e:
- raise TypeError('not iterable') from e
+ raise TypeError("not iterable") from e
def __contains__(self, key):
return key in self._dict
@@ -369,7 +376,7 @@ class OrderedFrozenSet(Set):
try:
return next(islice(self._dict, key, None))
except StopIteration:
- raise IndexError('index out of range')
+ raise IndexError("index out of range")
# handle keys using slice notation
return self.__class__(list(self._dict)[key])
@@ -384,8 +391,8 @@ class OrderedFrozenSet(Set):
return set(self._dict) == other
def __str__(self):
- elements_str = ', '.join(map(repr, self._dict))
- return f'{{{elements_str}}}'
+ elements_str = ", ".join(map(repr, self._dict))
+ return f"{{{elements_str}}}"
def __repr__(self):
return self.__str__()
@@ -413,7 +420,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet):
try:
self._dict = {x: None for x in iterable}
except TypeError as e:
- raise TypeError('not iterable') from e
+ raise TypeError("not iterable") from e
def add(self, value):
self._dict[value] = None
@@ -434,7 +441,7 @@ class OrderedSet(OrderedFrozenSet, MutableSet):
self._dict.update((x, None) for x in iterable)
def __hash__(self):
- raise TypeError(f'unhashable type: {self.__class__.__name__!r}')
+ raise TypeError(f"unhashable type: {self.__class__.__name__!r}")
class IndeterminantDict:
@@ -473,12 +480,21 @@ class IndeterminantDict:
def __unmodifiable(func, *args):
raise TypeError(f"indeterminate dict: '{func}()' can't modify {args!r}")
- for func in ('__delitem__', '__setitem__', 'setdefault', 'popitem', 'update', 'clear'):
+
+ for func in (
+ "__delitem__",
+ "__setitem__",
+ "setdefault",
+ "popitem",
+ "update",
+ "clear",
+ ):
locals()[func] = partial(__unmodifiable, func)
def __indeterminate(func, *args):
raise TypeError(f"indeterminate dict: '{func}()' is inaccessible")
- for func in ('__iter__', '__len__', 'keys', 'values', 'items'):
+
+ for func in ("__iter__", "__len__", "keys", "values", "items"):
locals()[func] = partial(__indeterminate, func)
@@ -650,6 +666,7 @@ def _KeyError_to_Attr(functor):
return functor(self, *args)
except KeyError:
raise AttributeError(args[0])
+
inner.__name__ = functor.__name__
inner.__doc__ = functor.__doc__
return inner
@@ -681,9 +698,9 @@ def inject_getitem_as_getattr(scope):
:param scope: the scope of a class to modify, adding methods as needed
"""
- scope.setdefault('__getattr__', _KeyError_to_Attr(operator.__getitem__))
- scope.setdefault('__delattr__', _KeyError_to_Attr(operator.__delitem__))
- scope.setdefault('__setattr__', _KeyError_to_Attr(operator.__setitem__))
+ scope.setdefault("__getattr__", _KeyError_to_Attr(operator.__getitem__))
+ scope.setdefault("__delattr__", _KeyError_to_Attr(operator.__delitem__))
+ scope.setdefault("__setattr__", _KeyError_to_Attr(operator.__setitem__))
class AttrAccessible(dict):
@@ -713,7 +730,7 @@ class ProxiedAttrs(DictMixin):
:param target: The object to wrap.
"""
- __slots__ = ('__target__',)
+ __slots__ = ("__target__",)
def __init__(self, target):
self.__target__ = target
@@ -860,7 +877,7 @@ class _SlottedDict(DictMixin):
def make_SlottedDict_kls(keys):
"""Create a space efficient mapping class with a limited set of keys."""
new_keys = tuple(sorted(keys))
- cls_name = f'SlottedDict_{hash(new_keys)}'
+ cls_name = f"SlottedDict_{hash(new_keys)}"
o = globals().get(cls_name, None)
if o is None:
o = type(cls_name, (_SlottedDict,), {})
diff --git a/src/snakeoil/modules.py b/src/snakeoil/modules.py
index 740ea245..ec69701b 100644
--- a/src/snakeoil/modules.py
+++ b/src/snakeoil/modules.py
@@ -14,6 +14,7 @@ class FailedImport(ImportError):
"""
Raised when a requested target cannot be imported
"""
+
def __init__(self, trg, e):
super().__init__(self, f"Failed importing target '{trg}': '{e}'")
self.trg, self.e = trg, e
diff --git a/src/snakeoil/obj.py b/src/snakeoil/obj.py
index 73cede99..a8598bfb 100644
--- a/src/snakeoil/obj.py
+++ b/src/snakeoil/obj.py
@@ -74,7 +74,6 @@ try to proxy builtin objects like tuples, lists, dicts, sets, etc.
"""
-
__all__ = ("DelayedInstantiation", "DelayedInstantiation_kls", "make_kls", "popattr")
from . import klass
@@ -87,14 +86,25 @@ from . import klass
# pointless class creation- thus having two separate lists.
base_kls_descriptors = [
- '__delattr__', '__hash__', '__reduce__',
- '__reduce_ex__', '__repr__', '__setattr__', '__str__',
- '__format__', '__subclasshook__', # >=py2.6
- '__le__', '__lt__', '__ge__', '__gt__', '__eq__', '__ne__', # py3
- '__dir__', # >=py3.3
+ "__delattr__",
+ "__hash__",
+ "__reduce__",
+ "__reduce_ex__",
+ "__repr__",
+ "__setattr__",
+ "__str__",
+ "__format__",
+ "__subclasshook__", # >=py2.6
+ "__le__",
+ "__lt__",
+ "__ge__",
+ "__gt__",
+ "__eq__",
+ "__ne__", # py3
+ "__dir__", # >=py3.3
]
-if hasattr(object, '__sizeof__'):
- base_kls_descriptors.append('__sizeof__')
+if hasattr(object, "__sizeof__"):
+ base_kls_descriptors.append("__sizeof__")
base_kls_descriptors = frozenset(base_kls_descriptors)
@@ -134,13 +144,13 @@ class BaseDelayedObject:
def __getattribute__(self, attr):
obj = object.__getattribute__(self, "__obj__")
if obj is None:
- if attr == '__class__':
+ if attr == "__class__":
return object.__getattribute__(self, "__delayed__")[0]
- elif attr == '__doc__':
+ elif attr == "__doc__":
kls = object.__getattribute__(self, "__delayed__")[0]
- return getattr(kls, '__doc__', None)
+ return getattr(kls, "__doc__", None)
- obj = object.__getattribute__(self, '__instantiate_proxy_instance__')()
+ obj = object.__getattribute__(self, "__instantiate_proxy_instance__")()
if attr == "__obj__":
# special casing for klass.alias_method
@@ -157,61 +167,122 @@ class BaseDelayedObject:
# special case the normal descriptors
for x in base_kls_descriptors:
locals()[x] = klass.alias_method(
- "__obj__.%s" % (x,),
- doc=getattr(getattr(object, x), '__doc__', None))
+ "__obj__.%s" % (x,), doc=getattr(getattr(object, x), "__doc__", None)
+ )
# pylint: disable=undefined-loop-variable
del x
# note that we ignore __getattribute__; we already handle it.
-kls_descriptors = frozenset([
- # rich comparison protocol...
- '__le__', '__lt__', '__eq__', '__ne__', '__gt__', '__ge__',
- # unicode conversion
- '__unicode__',
- # truth...
- '__bool__',
- # container protocol...
- '__len__', '__getitem__', '__setitem__', '__delitem__',
- '__iter__', '__contains__', '__index__', '__reversed__',
- # deprecated sequence protocol bits...
- '__getslice__', '__setslice__', '__delslice__',
- # numeric...
- '__add__', '__sub__', '__mul__', '__floordiv__', '__mod__',
- '__divmod__', '__pow__', '__lshift__', '__rshift__',
- '__and__', '__xor__', '__or__', '__div__', '__truediv__',
- '__rad__', '__rsub__', '__rmul__', '__rdiv__', '__rtruediv__',
- '__rfloordiv__', '__rmod__', '__rdivmod__', '__rpow__',
- '__rlshift__', '__rrshift__', '__rand__', '__rxor__', '__ror__',
- '__iadd__', '__isub__', '__imul__', '__idiv__', '__itruediv__',
- '__ifloordiv__', '__imod__', '__ipow__', '__ilshift__',
- '__irshift__', '__iand__', '__ixor__', '__ior__',
- '__neg__', '__pos__', '__abs__', '__invert__', '__complex__',
- '__int__', '__long__', '__float__', '__oct__', '__hex__',
- '__coerce__', '__trunc__', '__radd__', '__floor__', '__ceil__',
- '__round__',
- # remaining...
- '__call__', '__sizeof__',
-])
+kls_descriptors = frozenset(
+ [
+ # rich comparison protocol...
+ "__le__",
+ "__lt__",
+ "__eq__",
+ "__ne__",
+ "__gt__",
+ "__ge__",
+ # unicode conversion
+ "__unicode__",
+ # truth...
+ "__bool__",
+ # container protocol...
+ "__len__",
+ "__getitem__",
+ "__setitem__",
+ "__delitem__",
+ "__iter__",
+ "__contains__",
+ "__index__",
+ "__reversed__",
+ # deprecated sequence protocol bits...
+ "__getslice__",
+ "__setslice__",
+ "__delslice__",
+ # numeric...
+ "__add__",
+ "__sub__",
+ "__mul__",
+ "__floordiv__",
+ "__mod__",
+ "__divmod__",
+ "__pow__",
+ "__lshift__",
+ "__rshift__",
+ "__and__",
+ "__xor__",
+ "__or__",
+ "__div__",
+ "__truediv__",
+ "__rad__",
+ "__rsub__",
+ "__rmul__",
+ "__rdiv__",
+ "__rtruediv__",
+ "__rfloordiv__",
+ "__rmod__",
+ "__rdivmod__",
+ "__rpow__",
+ "__rlshift__",
+ "__rrshift__",
+ "__rand__",
+ "__rxor__",
+ "__ror__",
+ "__iadd__",
+ "__isub__",
+ "__imul__",
+ "__idiv__",
+ "__itruediv__",
+ "__ifloordiv__",
+ "__imod__",
+ "__ipow__",
+ "__ilshift__",
+ "__irshift__",
+ "__iand__",
+ "__ixor__",
+ "__ior__",
+ "__neg__",
+ "__pos__",
+ "__abs__",
+ "__invert__",
+ "__complex__",
+ "__int__",
+ "__long__",
+ "__float__",
+ "__oct__",
+ "__hex__",
+ "__coerce__",
+ "__trunc__",
+ "__radd__",
+ "__floor__",
+ "__ceil__",
+ "__round__",
+ # remaining...
+ "__call__",
+ "__sizeof__",
+ ]
+)
kls_descriptors = kls_descriptors.difference(base_kls_descriptors)
-descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}")
- for k in kls_descriptors}
+descriptor_overrides = {k: klass.alias_method(f"__obj__.{k}") for k in kls_descriptors}
_method_cache = {}
+
+
def make_kls(kls, proxy_base=BaseDelayedObject):
special_descriptors = kls_descriptors.intersection(dir(kls))
- doc = getattr(kls, '__doc__', None)
+ doc = getattr(kls, "__doc__", None)
if not special_descriptors and doc is None:
return proxy_base
key = (tuple(sorted(special_descriptors)), doc)
o = _method_cache.get(key, None)
if o is None:
+
class CustomDelayedObject(proxy_base):
- locals().update((k, descriptor_overrides[k])
- for k in special_descriptors)
+ locals().update((k, descriptor_overrides[k]) for k in special_descriptors)
__doc__ = doc
o = CustomDelayedObject
@@ -230,6 +301,8 @@ def DelayedInstantiation_kls(kls, *a, **kwd):
_class_cache = {}
+
+
def DelayedInstantiation(resultant_kls, func, *a, **kwd):
"""Generate an objects that does not get initialized before it is used.
diff --git a/src/snakeoil/osutils/__init__.py b/src/snakeoil/osutils/__init__.py
index f0ea4980..036a78df 100644
--- a/src/snakeoil/osutils/__init__.py
+++ b/src/snakeoil/osutils/__init__.py
@@ -34,9 +34,18 @@ pretty quickly.
"""
__all__ = (
- 'abspath', 'abssymlink', 'ensure_dirs', 'join', 'pjoin', 'listdir_files',
- 'listdir_dirs', 'listdir', 'readdir', 'normpath', 'unlink_if_exists',
- 'supported_systems',
+ "abspath",
+ "abssymlink",
+ "ensure_dirs",
+ "join",
+ "pjoin",
+ "listdir_files",
+ "listdir_dirs",
+ "listdir",
+ "readdir",
+ "normpath",
+ "unlink_if_exists",
+ "supported_systems",
)
import errno
@@ -86,14 +95,18 @@ def supported_systems(*systems):
...
NotImplementedError: func2 not supported on nonexistent
"""
+
def _decorator(f):
def _wrapper(*args, **kwargs):
if sys.platform.startswith(systems):
return f(*args, **kwargs)
else:
- raise NotImplementedError('%s not supported on %s'
- % (f.__name__, sys.platform))
+ raise NotImplementedError(
+ "%s not supported on %s" % (f.__name__, sys.platform)
+ )
+
return _wrapper
+
return _decorator
@@ -134,7 +147,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True):
try:
um = os.umask(0)
# if the dir perms would lack +wx, we have to force it
- force_temp_perms = ((mode & 0o300) != 0o300)
+ force_temp_perms = (mode & 0o300) != 0o300
resets = []
apath = normpath(os.path.abspath(path))
sticky_parent = False
@@ -149,7 +162,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True):
# if it's a subdir, we need +wx at least
if apath != base:
- sticky_parent = (st.st_mode & stat.S_ISGID)
+ sticky_parent = st.st_mode & stat.S_ISGID
except OSError:
# nothing exists.
@@ -185,8 +198,7 @@ def ensure_dirs(path, gid=-1, uid=-1, mode=0o777, minimal=True):
return False
try:
- if ((gid != -1 and gid != st.st_gid) or
- (uid != -1 and uid != st.st_uid)):
+ if (gid != -1 and gid != st.st_gid) or (uid != -1 and uid != st.st_uid):
os.chown(path, uid, gid)
if minimal:
if mode != (st.st_mode & mode):
@@ -207,9 +219,9 @@ def abssymlink(path):
a symlink
"""
mylink = os.readlink(path)
- if mylink[0] != '/':
+ if mylink[0] != "/":
mydir = os.path.dirname(path)
- mylink = mydir + '/' + mylink
+ mylink = mydir + "/" + mylink
return normpath(mylink)
@@ -256,7 +268,7 @@ def normpath(mypath: str) -> str:
`os.path.normpath` only in that it'll convert leading '//' into '/'
"""
newpath = os.path.normpath(mypath)
- double_sep = b'//' if isinstance(newpath, bytes) else '//'
+ double_sep = b"//" if isinstance(newpath, bytes) else "//"
if newpath.startswith(double_sep):
return newpath[1:]
return newpath
@@ -306,9 +318,9 @@ def fallback_access(path, mode, root=0):
return mode == (mode & (st.st_mode & 0x7))
-if os.uname()[0].lower() == 'sunos':
+if os.uname()[0].lower() == "sunos":
access = fallback_access
- access.__name__ = 'access'
+ access.__name__ = "access"
else:
access = os.access
diff --git a/src/snakeoil/osutils/mount.py b/src/snakeoil/osutils/mount.py
index 05eb47d5..b28853d2 100644
--- a/src/snakeoil/osutils/mount.py
+++ b/src/snakeoil/osutils/mount.py
@@ -1,4 +1,4 @@
-__all__ = ('mount', 'umount')
+__all__ = ("mount", "umount")
import ctypes
import os
@@ -40,10 +40,10 @@ MNT_EXPIRE = 4
UMOUNT_NOFOLLOW = 8
-@supported_systems('linux')
+@supported_systems("linux")
def mount(source, target, fstype, flags, data=None):
"""Call mount(2); see the man page for details."""
- libc = ctypes.CDLL(find_library('c'), use_errno=True)
+ libc = ctypes.CDLL(find_library("c"), use_errno=True)
source = source.encode() if isinstance(source, str) else source
target = target.encode() if isinstance(target, str) else target
fstype = fstype.encode() if isinstance(fstype, str) else fstype
@@ -52,10 +52,10 @@ def mount(source, target, fstype, flags, data=None):
raise OSError(e, os.strerror(e))
-@supported_systems('linux')
+@supported_systems("linux")
def umount(target, flags=None):
"""Call umount or umount2; see the umount(2) man page for details."""
- libc = ctypes.CDLL(find_library('c'), use_errno=True)
+ libc = ctypes.CDLL(find_library("c"), use_errno=True)
target = target.encode() if isinstance(target, str) else target
args = []
func = libc.umount
diff --git a/src/snakeoil/osutils/native_readdir.py b/src/snakeoil/osutils/native_readdir.py
index 0efb9c92..b129a9bf 100644
--- a/src/snakeoil/osutils/native_readdir.py
+++ b/src/snakeoil/osutils/native_readdir.py
@@ -3,8 +3,18 @@
import errno
import os
-from stat import (S_IFBLK, S_IFCHR, S_IFDIR, S_IFIFO, S_IFLNK, S_IFMT, S_IFREG, S_IFSOCK, S_ISDIR,
- S_ISREG)
+from stat import (
+ S_IFBLK,
+ S_IFCHR,
+ S_IFDIR,
+ S_IFIFO,
+ S_IFLNK,
+ S_IFMT,
+ S_IFREG,
+ S_IFSOCK,
+ S_ISDIR,
+ S_ISREG,
+)
from ..mappings import ProtectedDict
@@ -14,6 +24,7 @@ listdir = os.listdir
# import cycle.
pjoin = os.path.join
+
def stat_swallow_enoent(path, check, default=False, stat=os.stat):
try:
return check(stat(path).st_mode)
@@ -22,6 +33,7 @@ def stat_swallow_enoent(path, check, default=False, stat=os.stat):
return default
raise
+
def listdir_dirs(path, followSymlinks=True):
"""
Return a list of all subdirectories within a directory
@@ -36,11 +48,12 @@ def listdir_dirs(path, followSymlinks=True):
pjf = pjoin
lstat = os.lstat
if followSymlinks:
- return [x for x in os.listdir(path) if
- stat_swallow_enoent(pjf(path, x), scheck)]
+ return [
+ x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck)
+ ]
lstat = os.lstat
- return [x for x in os.listdir(path) if
- scheck(lstat(pjf(path, x)).st_mode)]
+ return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)]
+
def listdir_files(path, followSymlinks=True):
"""
@@ -56,24 +69,28 @@ def listdir_files(path, followSymlinks=True):
scheck = S_ISREG
pjf = pjoin
if followSymlinks:
- return [x for x in os.listdir(path) if
- stat_swallow_enoent(pjf(path, x), scheck)]
+ return [
+ x for x in os.listdir(path) if stat_swallow_enoent(pjf(path, x), scheck)
+ ]
lstat = os.lstat
- return [x for x in os.listdir(path) if
- scheck(lstat(pjf(path, x)).st_mode)]
+ return [x for x in os.listdir(path) if scheck(lstat(pjf(path, x)).st_mode)]
+
# we store this outside the function to ensure that
# the strings used are reused, thus avoiding unneeded
# allocations
-d_type_mapping = ProtectedDict({
- S_IFREG: "file",
- S_IFDIR: "directory",
- S_IFLNK: "symlink",
- S_IFCHR: "chardev",
- S_IFBLK: "block",
- S_IFSOCK: "socket",
- S_IFIFO: "fifo",
-})
+d_type_mapping = ProtectedDict(
+ {
+ S_IFREG: "file",
+ S_IFDIR: "directory",
+ S_IFLNK: "symlink",
+ S_IFCHR: "chardev",
+ S_IFBLK: "block",
+ S_IFSOCK: "socket",
+ S_IFIFO: "fifo",
+ }
+)
+
def readdir(path):
"""
diff --git a/src/snakeoil/pickling.py b/src/snakeoil/pickling.py
index e212e6dd..9707ecd0 100644
--- a/src/snakeoil/pickling.py
+++ b/src/snakeoil/pickling.py
@@ -3,7 +3,8 @@ pickling convenience module
"""
__all__ = (
- "iter_stream", "dump_stream",
+ "iter_stream",
+ "dump_stream",
)
from pickle import dump, load
diff --git a/src/snakeoil/process/__init__.py b/src/snakeoil/process/__init__.py
index f2422695..8966fd65 100644
--- a/src/snakeoil/process/__init__.py
+++ b/src/snakeoil/process/__init__.py
@@ -44,7 +44,7 @@ def get_exit_status(status: int):
if os.WIFSIGNALED(status):
return 128 + os.WTERMSIG(status)
else:
- assert os.WIFEXITED(status), 'Unexpected exit status %r' % status
+ assert os.WIFEXITED(status), "Unexpected exit status %r" % status
return os.WEXITSTATUS(status)
@@ -84,16 +84,14 @@ def exit_as_status(status: int):
class CommandNotFound(Exception):
-
def __init__(self, command):
- super().__init__(f'failed to find binary: {command!r}')
+ super().__init__(f"failed to find binary: {command!r}")
self.command = command
class ProcessNotFound(Exception):
-
def __init__(self, pid):
- super().__init__(f'nonexistent process: {pid}')
+ super().__init__(f"nonexistent process: {pid}")
closerange = os.closerange
diff --git a/src/snakeoil/process/namespaces.py b/src/snakeoil/process/namespaces.py
index fd6bd74e..6a823c6c 100644
--- a/src/snakeoil/process/namespaces.py
+++ b/src/snakeoil/process/namespaces.py
@@ -11,8 +11,15 @@ import socket
import subprocess
import sys
-from ..osutils.mount import (MS_NODEV, MS_NOEXEC, MS_NOSUID, MS_PRIVATE, MS_REC, MS_RELATIME,
- MS_SLAVE)
+from ..osutils.mount import (
+ MS_NODEV,
+ MS_NOEXEC,
+ MS_NOSUID,
+ MS_PRIVATE,
+ MS_REC,
+ MS_RELATIME,
+ MS_SLAVE,
+)
from ..osutils.mount import mount as _mount
from . import exit_as_status
@@ -39,7 +46,7 @@ def setns(fd, nstype):
fp = open(fd)
fd = fp.fileno()
- libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True)
+ libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
if libc.setns(ctypes.c_int(fd), ctypes.c_int(nstype)) != 0:
e = ctypes.get_errno()
raise OSError(e, os.strerror(e))
@@ -54,7 +61,7 @@ def unshare(flags):
:param flags: Namespaces to unshare; bitwise OR of CLONE_* flags.
:raises OSError: if unshare failed.
"""
- libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True)
+ libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
if libc.unshare(ctypes.c_int(flags)) != 0:
e = ctypes.get_errno()
raise OSError(e, os.strerror(e))
@@ -140,15 +147,13 @@ def create_pidns():
# Make sure to unshare the existing mount point if needed. Some distros
# create shared mount points everywhere by default.
try:
- _mount(None, '/proc', 'proc', MS_PRIVATE | MS_REC)
+ _mount(None, "/proc", "proc", MS_PRIVATE | MS_REC)
except OSError as e:
if e.errno != errno.EINVAL:
raise
# The child needs its own proc mount as it'll be different.
- _mount(
- 'proc', '/proc', 'proc',
- MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME)
+ _mount("proc", "/proc", "proc", MS_NOSUID | MS_NODEV | MS_NOEXEC | MS_RELATIME)
if pid := os.fork():
# Mask SIGINT with the assumption that the child will catch & process it.
@@ -195,12 +200,13 @@ def create_netns():
# Since we've unshared the net namespace, we need to bring up loopback.
# The kernel automatically adds the various ip addresses, so skip that.
try:
- subprocess.call(['ip', 'link', 'set', 'up', 'lo'])
+ subprocess.call(["ip", "link", "set", "up", "lo"])
except OSError as e:
if e.errno == errno.ENOENT:
sys.stderr.write(
- 'warning: could not bring up loopback for network; '
- 'install the iproute2 package\n')
+ "warning: could not bring up loopback for network; "
+ "install the iproute2 package\n"
+ )
else:
raise
@@ -243,16 +249,17 @@ def create_userns():
# For all other errors, abort. They shouldn't happen.
raise
- with open('/proc/self/setgroups', 'w') as f:
- f.write('deny')
- with open('/proc/self/uid_map', 'w') as f:
- f.write('0 %s 1\n' % uid)
- with open('/proc/self/gid_map', 'w') as f:
- f.write('0 %s 1\n' % gid)
+ with open("/proc/self/setgroups", "w") as f:
+ f.write("deny")
+ with open("/proc/self/uid_map", "w") as f:
+ f.write("0 %s 1\n" % uid)
+ with open("/proc/self/gid_map", "w") as f:
+ f.write("0 %s 1\n" % gid)
-def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False,
- user=False, hostname=None):
+def simple_unshare(
+ mount=True, uts=True, ipc=True, net=False, pid=False, user=False, hostname=None
+):
"""Simpler helper for setting up namespaces quickly.
If support for any namespace type is not available, we'll silently skip it.
@@ -278,7 +285,7 @@ def simple_unshare(mount=True, uts=True, ipc=True, net=False, pid=False,
# on systems that share the rootfs by default, but allow events in the
# parent to propagate down.
try:
- _mount(None, '/', None, MS_REC | MS_SLAVE)
+ _mount(None, "/", None, MS_REC | MS_SLAVE)
except OSError as e:
if e.errno != errno.EINVAL:
raise
diff --git a/src/snakeoil/process/spawn.py b/src/snakeoil/process/spawn.py
index 3413e3f2..48b60b1a 100644
--- a/src/snakeoil/process/spawn.py
+++ b/src/snakeoil/process/spawn.py
@@ -3,8 +3,12 @@ subprocess related functionality
"""
__all__ = [
- "cleanup_pids", "spawn", "spawn_sandbox", "spawn_bash",
- "spawn_get_output", "bash_version",
+ "cleanup_pids",
+ "spawn",
+ "spawn_sandbox",
+ "spawn_bash",
+ "spawn_get_output",
+ "bash_version",
]
import atexit
@@ -17,11 +21,12 @@ from ..mappings import ProtectedDict
from ..osutils import access
from . import CommandNotFound, closerange, find_binary
-BASH_BINARY = find_binary('bash', fallback='/bin/bash')
-SANDBOX_BINARY = find_binary('sandbox', fallback='/usr/bin/sandbox')
+BASH_BINARY = find_binary("bash", fallback="/bin/bash")
+SANDBOX_BINARY = find_binary("sandbox", fallback="/usr/bin/sandbox")
try:
import resource
+
max_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
except ImportError:
max_fd_limit = 256
@@ -36,8 +41,14 @@ def bash_version(force=False):
pass
try:
ret, ver = spawn_get_output(
- [BASH_BINARY, '--norc', '--noprofile', '-c',
- 'printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}'])
+ [
+ BASH_BINARY,
+ "--norc",
+ "--noprofile",
+ "-c",
+ "printf ${BASH_VERSINFO[0]}.${BASH_VERSINFO[1]}.${BASH_VERSINFO[2]}",
+ ]
+ )
if ret == 0:
try:
ver = ver[0]
@@ -54,7 +65,7 @@ def bash_version(force=False):
def spawn_bash(mycommand, debug=False, name=None, **kwds):
"""spawn the command via bash -c"""
- args = [BASH_BINARY, '--norc', '--noprofile']
+ args = [BASH_BINARY, "--norc", "--noprofile"]
if debug:
# Print commands and their arguments as they are executed.
args.append("-x")
@@ -84,6 +95,8 @@ def spawn_sandbox(mycommand, name=None, **kwds):
_exithandlers = []
+
+
def atexit_register(func, *args, **kargs):
"""Wrapper around atexit.register that is needed in order to track
what is registered. For example, when portage restarts itself via
@@ -119,6 +132,8 @@ atexit.register(run_exitfuncs)
# we exit. spawn() takes care of adding and removing pids to this list
# as it creates and cleans up processes.
spawned_pids = []
+
+
def cleanup_pids(pids=None):
"""reap list of pids if specified, else all children"""
@@ -146,8 +161,19 @@ def cleanup_pids(pids=None):
pass
-def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False,
- uid=None, gid=None, groups=None, umask=None, cwd=None, pgid=None):
+def spawn(
+ mycommand,
+ env=None,
+ name=None,
+ fd_pipes=None,
+ returnpid=False,
+ uid=None,
+ gid=None,
+ groups=None,
+ umask=None,
+ cwd=None,
+ pgid=None,
+):
"""wrapper around execve
@@ -177,8 +203,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False,
# 'Catch "Exception"'
# pylint: disable-msg=W0703
try:
- _exec(binary, mycommand, name, fd_pipes, env, gid, groups,
- uid, umask, cwd, pgid)
+ _exec(
+ binary,
+ mycommand,
+ name,
+ fd_pipes,
+ env,
+ gid,
+ groups,
+ uid,
+ umask,
+ cwd,
+ pgid,
+ )
except Exception as e:
# We need to catch _any_ exception so that it doesn't
# propogate out of this function and cause exiting
@@ -228,8 +265,19 @@ def spawn(mycommand, env=None, name=None, fd_pipes=None, returnpid=False,
return 0
-def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None,
- groups=None, uid=None, umask=None, cwd=None, pgid=None):
+def _exec(
+ binary,
+ mycommand,
+ name=None,
+ fd_pipes=None,
+ env=None,
+ gid=None,
+ groups=None,
+ uid=None,
+ umask=None,
+ cwd=None,
+ pgid=None,
+):
"""internal function to handle exec'ing the child process.
If it succeeds this function does not return. It might raise an
@@ -321,8 +369,15 @@ def _exec(binary, mycommand, name=None, fd_pipes=None, env=None, gid=None,
os.execve(binary, myargs, env)
-def spawn_get_output(mycommand, spawn_type=None, raw_exit_code=False, collect_fds=(1,),
- fd_pipes=None, split_lines=True, **kwds):
+def spawn_get_output(
+ mycommand,
+ spawn_type=None,
+ raw_exit_code=False,
+ collect_fds=(1,),
+ fd_pipes=None,
+ split_lines=True,
+ **kwds,
+):
"""Call spawn, collecting the output to fd's specified in collect_fds list.
@@ -386,8 +441,8 @@ def process_exit_code(retval):
:return: The exit code if it exit'd, the signal if it died from signalling.
"""
# If it got a signal, return the signal that was sent.
- if retval & 0xff:
- return (retval & 0xff) << 8
+ if retval & 0xFF:
+ return (retval & 0xFF) << 8
# Otherwise, return its exit code.
return retval >> 8
@@ -399,7 +454,7 @@ class ExecutionFailure(Exception):
self.msg = msg
def __str__(self):
- return f'Execution Failure: {self.msg}'
+ return f"Execution Failure: {self.msg}"
# cached capabilities
@@ -411,7 +466,7 @@ def is_sandbox_capable(force=False):
return is_sandbox_capable.cached_result
except AttributeError:
pass
- if 'SANDBOX_ACTIVE' in os.environ:
+ if "SANDBOX_ACTIVE" in os.environ:
# we can not spawn a sandbox inside another one
res = False
elif not (os.path.isfile(SANDBOX_BINARY) and access(SANDBOX_BINARY, os.X_OK)):
@@ -432,5 +487,5 @@ def is_userpriv_capable(force=False):
return is_userpriv_capable.cached_result
except AttributeError:
pass
- res = is_userpriv_capable.cached_result = (os.getuid() == 0)
+ res = is_userpriv_capable.cached_result = os.getuid() == 0
return res
diff --git a/src/snakeoil/sequences.py b/src/snakeoil/sequences.py
index c041136f..b80b1019 100644
--- a/src/snakeoil/sequences.py
+++ b/src/snakeoil/sequences.py
@@ -1,9 +1,14 @@
"""sequence related operations and classes"""
__all__ = (
- 'unstable_unique', 'stable_unique', 'iter_stable_unique',
- 'iflatten_instance', 'iflatten_func', 'ChainedLists', 'predicate_split',
- 'split_negations',
+ "unstable_unique",
+ "stable_unique",
+ "iter_stable_unique",
+ "iflatten_instance",
+ "iflatten_func",
+ "ChainedLists",
+ "predicate_split",
+ "split_negations",
)
from typing import Any, Callable, Iterable, Type
@@ -89,7 +94,9 @@ def iter_stable_unique(iterable):
break
-def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes)) -> Iterable:
+def iflatten_instance(
+ l: Iterable, skip_flattening: Iterable[Type] = (str, bytes)
+) -> Iterable:
"""collapse [[1],2] into [1,2]
:param skip_flattening: list of classes to not descend through
@@ -103,9 +110,10 @@ def iflatten_instance(l: Iterable, skip_flattening: Iterable[Type] = (str, bytes
try:
while True:
x = next(iters)
- if (hasattr(x, '__iter__') and not (
- isinstance(x, skip_flattening) or (
- isinstance(x, (str, bytes)) and len(x) == 1))):
+ if hasattr(x, "__iter__") and not (
+ isinstance(x, skip_flattening)
+ or (isinstance(x, (str, bytes)) and len(x) == 1)
+ ):
iters.appendleft(x)
else:
yield x
@@ -128,7 +136,7 @@ def iflatten_func(l: Iterable, skip_func: Callable[[Any], bool]) -> Iterable:
try:
while True:
x = next(iters)
- if hasattr(x, '__iter__') and not skip_func(x):
+ if hasattr(x, "__iter__") and not skip_func(x):
iters.appendleft(x)
else:
yield x
@@ -164,6 +172,7 @@ class ChainedLists:
...
TypeError: not mutable
"""
+
__slots__ = ("_lists", "__weakref__")
def __init__(self, *lists):
@@ -258,7 +267,7 @@ def predicate_split(func, stream, key=None):
def split_negations(iterable, func=str):
- """"Split an iterable into negative and positive elements.
+ """ "Split an iterable into negative and positive elements.
:param iterable: iterable targeted for splitting
:param func: wrapper method to modify tokens
@@ -267,7 +276,7 @@ def split_negations(iterable, func=str):
"""
neg, pos = [], []
for token in iterable:
- if token[0] == '-':
+ if token[0] == "-":
if len(token) == 1:
raise ValueError("'-' negation without a token")
token = token[1:]
@@ -281,7 +290,7 @@ def split_negations(iterable, func=str):
def split_elements(iterable, func=str):
- """"Split an iterable into negative, neutral, and positive elements.
+ """ "Split an iterable into negative, neutral, and positive elements.
:param iterable: iterable targeted for splitting
:param func: wrapper method to modify tokens
@@ -289,11 +298,11 @@ def split_elements(iterable, func=str):
:return: Tuple containing negative, neutral, and positive element tuples, respectively.
"""
neg, neu, pos = [], [], []
- token_map = {'-': neg, '+': pos}
+ token_map = {"-": neg, "+": pos}
for token in iterable:
if token[0] in token_map:
if len(token) == 1:
- raise ValueError('%r without a token' % (token[0],))
+ raise ValueError("%r without a token" % (token[0],))
l = token_map[token[0]]
token = token[1:]
else:
diff --git a/src/snakeoil/stringio.py b/src/snakeoil/stringio.py
index c17db923..1a392eb4 100644
--- a/src/snakeoil/stringio.py
+++ b/src/snakeoil/stringio.py
@@ -27,7 +27,7 @@ is usable under both py2k and py3k.
"""
# TODO: deprecated, remove in 0.9.0
-__all__ = ('text_readonly', 'bytes_readonly')
+__all__ = ("text_readonly", "bytes_readonly")
import io
diff --git a/src/snakeoil/strings.py b/src/snakeoil/strings.py
index 7d7b2a83..d0dc58ab 100644
--- a/src/snakeoil/strings.py
+++ b/src/snakeoil/strings.py
@@ -2,10 +2,10 @@
from .demandload import demand_compile_regexp
-demand_compile_regexp('_whitespace_regex', r'^(?P<indent>\s+)')
+demand_compile_regexp("_whitespace_regex", r"^(?P<indent>\s+)")
-def pluralism(obj, none=None, singular='', plural='s'):
+def pluralism(obj, none=None, singular="", plural="s"):
"""Return singular or plural suffix depending on object's length or value."""
# default to plural for empty objects, e.g. there are 0 repos
if none is None:
@@ -27,16 +27,16 @@ def pluralism(obj, none=None, singular='', plural='s'):
def doc_dedent(s):
"""Support dedenting docstrings with initial line having no indentation."""
try:
- lines = s.split('\n')
+ lines = s.split("\n")
except AttributeError:
- raise TypeError(f'{s!r} is not a string')
+ raise TypeError(f"{s!r} is not a string")
if lines:
# find first line with an indent if one exists
for line in lines:
if mo := _whitespace_regex.match(line):
- indent = mo.group('indent')
+ indent = mo.group("indent")
break
else:
- indent = ''
+ indent = ""
len_i = len(indent)
- return '\n'.join(x[len_i:] if x.startswith(indent) else x for x in lines)
+ return "\n".join(x[len_i:] if x.startswith(indent) else x for x in lines)
diff --git a/src/snakeoil/tar.py b/src/snakeoil/tar.py
index 7daa9eb0..74061a58 100644
--- a/src/snakeoil/tar.py
+++ b/src/snakeoil/tar.py
@@ -43,13 +43,30 @@ class TarInfo(tarfile.TarInfo):
:ivar uname: same as TarInfo.uname, just interned via a property.
"""
- if not hasattr(tarfile.TarInfo, '__slots__'):
+ if not hasattr(tarfile.TarInfo, "__slots__"):
__slots__ = (
- "name", "mode", "uid", "gid", "size", "mtime", "chksum", "type",
- "linkname", "_uname", "_gname", "devmajor", "devminor", "prefix",
- "offset", "offset_data", "_buf", "sparse", "_link_target")
+ "name",
+ "mode",
+ "uid",
+ "gid",
+ "size",
+ "mtime",
+ "chksum",
+ "type",
+ "linkname",
+ "_uname",
+ "_gname",
+ "devmajor",
+ "devminor",
+ "prefix",
+ "offset",
+ "offset_data",
+ "_buf",
+ "sparse",
+ "_link_target",
+ )
else:
- __slots__ = ('_buf', '_uname', '_gname')
+ __slots__ = ("_buf", "_uname", "_gname")
def get_buf(self):
return self.tobuf()
diff --git a/src/snakeoil/test/__init__.py b/src/snakeoil/test/__init__.py
index bb9381f7..b93094a8 100644
--- a/src/snakeoil/test/__init__.py
+++ b/src/snakeoil/test/__init__.py
@@ -13,18 +13,20 @@ from snakeoil import klass
def random_str(length):
"""Return a random string of specified length."""
- return ''.join(random.choices(string.ascii_letters + string.digits, k=length))
+ return "".join(random.choices(string.ascii_letters + string.digits, k=length))
def coverage():
"""Extract coverage instance (if it exists) from the current running context."""
cov = None
import inspect
+
try:
import coverage
+
frame = inspect.currentframe()
while frame is not None:
- cov = getattr(frame.f_locals.get('self'), 'coverage', None)
+ cov = getattr(frame.f_locals.get("self"), "coverage", None)
if isinstance(cov, coverage.coverage):
break
frame = frame.f_back
@@ -33,7 +35,7 @@ def coverage():
return cov
-@klass.patch('os._exit')
+@klass.patch("os._exit")
def _os_exit(orig_exit, val):
"""Monkeypatch os._exit() to save coverage data before exit."""
cov = coverage()
@@ -51,7 +53,9 @@ def protect_process(functor, name=None):
if os.environ.get(_PROTECT_ENV_VAR, False):
return functor(self)
if name is None:
- name = f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}"
+ name = (
+ f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}"
+ )
runner_path = __file__
if runner_path.endswith(".pyc") or runner_path.endswith(".pyo"):
runner_path = runner_path.rsplit(".", maxsplit=1)[0] + ".py"
@@ -59,11 +63,18 @@ def protect_process(functor, name=None):
try:
os.environ[_PROTECT_ENV_VAR] = "yes"
args = [sys.executable, __file__, name]
- p = subprocess.Popen(args, shell=False, env=os.environ.copy(),
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ p = subprocess.Popen(
+ args,
+ shell=False,
+ env=os.environ.copy(),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ )
stdout, _ = p.communicate()
ret = p.wait()
- assert ret == 0, f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}"
+ assert (
+ ret == 0
+ ), f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}"
finally:
if wipe:
os.environ.pop(_PROTECT_ENV_VAR, None)
@@ -71,7 +82,7 @@ def protect_process(functor, name=None):
for x in ("__doc__", "__name__"):
if hasattr(functor, x):
setattr(_inner_run, x, getattr(functor, x))
- method_name = getattr(functor, '__name__', None)
+ method_name = getattr(functor, "__name__", None)
return _inner_run
@@ -84,4 +95,4 @@ def hide_imports(*import_names: str):
raise ImportError()
return orig_import(name, *args, **kwargs)
- return patch('builtins.__import__', side_effect=mock_import)
+ return patch("builtins.__import__", side_effect=mock_import)
diff --git a/src/snakeoil/test/argparse_helpers.py b/src/snakeoil/test/argparse_helpers.py
index 152e335d..bfcdc3b4 100644
--- a/src/snakeoil/test/argparse_helpers.py
+++ b/src/snakeoil/test/argparse_helpers.py
@@ -56,25 +56,24 @@ class Color(FormatterObject):
self.color = color
def __repr__(self):
- return f'<Color: mode - {self.mode}; color - {self.color}>'
+ return f"<Color: mode - {self.mode}; color - {self.color}>"
class Reset(FormatterObject):
__inst_caching__ = True
def __repr__(self):
- return '<Reset>'
+ return "<Reset>"
class Bold(FormatterObject):
__inst_caching__ = True
def __repr__(self):
- return '<Bold>'
+ return "<Bold>"
class ListStream(list):
-
def write(self, *args):
stringlist = []
objectlist = []
@@ -82,13 +81,16 @@ class ListStream(list):
if isinstance(arg, bytes):
stringlist.append(arg)
else:
- objectlist.append(b''.join(stringlist))
+ objectlist.append(b"".join(stringlist))
stringlist = []
objectlist.append(arg)
- objectlist.append(b''.join(stringlist))
+ objectlist.append(b"".join(stringlist))
# We use len because boolean ops shortcircuit
- if (len(self) and isinstance(self[-1], bytes) and
- isinstance(objectlist[0], bytes)):
+ if (
+ len(self)
+ and isinstance(self[-1], bytes)
+ and isinstance(objectlist[0], bytes)
+ ):
self[-1] = self[-1] + objectlist.pop(0)
self.extend(objectlist)
@@ -97,7 +99,6 @@ class ListStream(list):
class FakeStreamFormatter(PlainTextFormatter):
-
def __init__(self):
super().__init__(ListStream([]))
self.reset = Reset()
@@ -108,15 +109,15 @@ class FakeStreamFormatter(PlainTextFormatter):
self.stream = ListStream([])
def fg(self, color=None):
- return Color('fg', color)
+ return Color("fg", color)
def bg(self, color=None):
- return Color('bg', color)
+ return Color("bg", color)
def get_text_stream(self):
- return b''.join(
- (x for x in self.stream
- if not isinstance(x, FormatterObject))).decode('ascii')
+ return b"".join(
+ (x for x in self.stream if not isinstance(x, FormatterObject))
+ ).decode("ascii")
class ArgParseMixin:
@@ -148,7 +149,7 @@ class ArgParseMixin:
except Error as e:
assert message == e.message
else:
- raise AssertionError('no error triggered')
+ raise AssertionError("no error triggered")
def assertExit(self, status, message, *args, **kwargs):
"""Pass args, assert they trigger the right exit condition."""
@@ -158,7 +159,7 @@ class ArgParseMixin:
assert message == e.message.strip()
assert status == e.status
else:
- raise AssertionError('no exit triggered')
+ raise AssertionError("no exit triggered")
def assertOut(self, out, *args, **kwargs):
"""Like :obj:`assertOutAndErr` but without err."""
@@ -182,17 +183,25 @@ class ArgParseMixin:
main = self.get_main(options)
main(options, outformatter, errformatter)
diffs = []
- for name, strings, formatter in [('out', out, outformatter),
- ('err', err, errformatter)]:
+ for name, strings, formatter in [
+ ("out", out, outformatter),
+ ("err", err, errformatter),
+ ]:
actual = formatter.get_text_stream()
if strings:
- expected = '\n'.join(strings)
+ expected = "\n".join(strings)
else:
- expected = ''
+ expected = ""
if expected != actual:
- diffs.extend(difflib.unified_diff(
- strings, actual.split('\n')[:-1],
- 'expected %s' % (name,), 'actual', lineterm=''))
+ diffs.extend(
+ difflib.unified_diff(
+ strings,
+ actual.split("\n")[:-1],
+ "expected %s" % (name,),
+ "actual",
+ lineterm="",
+ )
+ )
if diffs:
- raise AssertionError('\n' + '\n'.join(diffs))
+ raise AssertionError("\n" + "\n".join(diffs))
return options
diff --git a/src/snakeoil/test/eq_hash_inheritance.py b/src/snakeoil/test/eq_hash_inheritance.py
index 5012f9d3..eaa42a37 100644
--- a/src/snakeoil/test/eq_hash_inheritance.py
+++ b/src/snakeoil/test/eq_hash_inheritance.py
@@ -3,7 +3,7 @@ from . import mixins
class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker):
- target_namespace = 'snakeoil'
+ target_namespace = "snakeoil"
singleton = object()
@@ -26,8 +26,8 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker):
# object sets __hash__/__eq__, which isn't usually
# intended to be inherited/reused
continue
- eq = getattr(parent, '__eq__', self.singleton)
- h = getattr(parent, '__hash__', self.singleton)
+ eq = getattr(parent, "__eq__", self.singleton)
+ h = getattr(parent, "__hash__", self.singleton)
if eq == object.__eq__ and h == object.__hash__:
continue
if eq and h:
@@ -37,10 +37,11 @@ class Test(mixins.TargetedNamespaceWalker, mixins.KlassWalker):
# pylint: disable=undefined-loop-variable
# 'parent' is guaranteed to be defined due to the 'else' clause above
- assert getattr(cls, '__hash__') is not None, (
+ assert getattr(cls, "__hash__") is not None, (
f"class '{cls.__module__}.{cls.__name__}' had its __hash__ reset, "
"while it would've inherited __hash__ from parent "
f"'{parent.__module__}.{parent.__name__}'; this occurs in py3k when "
"__eq__ is defined alone. If this is desired behaviour, set "
"__hash__intentionally_disabled__ to True to explicitly ignore this"
- " class")
+ " class"
+ )
diff --git a/src/snakeoil/test/mixins.py b/src/snakeoil/test/mixins.py
index aa66839c..0648de1e 100644
--- a/src/snakeoil/test/mixins.py
+++ b/src/snakeoil/test/mixins.py
@@ -16,23 +16,27 @@ class PythonNamespaceWalker:
# This is for py3.2/PEP3149; dso's now have the interp + major/minor embedded
# in the name.
# TODO: update this for pypy's naming
- abi_target = 'cpython-%i%i' % tuple(sys.version_info[:2])
+ abi_target = "cpython-%i%i" % tuple(sys.version_info[:2])
- module_blacklist = frozenset({
- 'snakeoil.cli.arghparse', 'snakeoil.pickling',
- })
+ module_blacklist = frozenset(
+ {
+ "snakeoil.cli.arghparse",
+ "snakeoil.pickling",
+ }
+ )
def _default_module_blacklister(self, target):
- return target in self.module_blacklist or target.startswith('snakeoil.dist')
+ return target in self.module_blacklist or target.startswith("snakeoil.dist")
def walk_namespace(self, namespace, **kwds):
- location = os.path.abspath(os.path.dirname(
- self.poor_mans_load(namespace).__file__))
- return self.get_modules(self.recurse(location), namespace=namespace,
- **kwds)
-
- def get_modules(self, feed, namespace=None, blacklist_func=None,
- ignore_failed_imports=None):
+ location = os.path.abspath(
+ os.path.dirname(self.poor_mans_load(namespace).__file__)
+ )
+ return self.get_modules(self.recurse(location), namespace=namespace, **kwds)
+
+ def get_modules(
+ self, feed, namespace=None, blacklist_func=None, ignore_failed_imports=None
+ ):
if ignore_failed_imports is None:
ignore_failed_imports = self.ignore_all_import_failures
if namespace is None:
@@ -57,7 +61,7 @@ class PythonNamespaceWalker:
raise
def recurse(self, location, valid_namespace=True):
- if os.path.dirname(location) == '__pycache__':
+ if os.path.dirname(location) == "__pycache__":
# Shouldn't be possible, but make sure we avoid this if it manages
# to occur.
return
@@ -78,10 +82,13 @@ class PythonNamespaceWalker:
# file disappeared under our feet... lock file from
# trial can cause this. ignore.
import logging
- logging.debug("file %r disappeared under our feet, ignoring",
- os.path.join(location, x))
- seen = set(['__init__'])
+ logging.debug(
+ "file %r disappeared under our feet, ignoring",
+ os.path.join(location, x),
+ )
+
+ seen = set(["__init__"])
for x, st in stats:
if not (x.startswith(".") or x.endswith("~")) and stat.S_ISREG(st):
if x.endswith((".py", ".pyc", ".pyo", ".so")):
@@ -89,8 +96,8 @@ class PythonNamespaceWalker:
# Ensure we're not looking at a >=py3k .so which injects
# the version name in...
if y not in seen:
- if '.' in y and x.endswith('.so'):
- y, abi = x.rsplit('.', 1)
+ if "." in y and x.endswith(".so"):
+ y, abi = x.rsplit(".", 1)
if abi != self.abi_target:
continue
seen.add(y)
@@ -135,6 +142,7 @@ class TargetedNamespaceWalker(PythonNamespaceWalker):
for _mod in self.walk_namespace(namespace):
pass
+
class _classWalker:
cls_blacklist = frozenset()
@@ -173,7 +181,6 @@ class _classWalker:
class SubclassWalker(_classWalker):
-
def walk_derivatives(self, cls, seen=None):
if len(inspect.signature(cls.__subclasses__).parameters) != 0:
return
@@ -193,7 +200,6 @@ class SubclassWalker(_classWalker):
class KlassWalker(_classWalker):
-
def walk_derivatives(self, cls, seen=None):
if len(inspect.signature(cls.__subclasses__).parameters) != 0:
return
diff --git a/src/snakeoil/test/modules.py b/src/snakeoil/test/modules.py
index 6b3a6a1a..0ae116bd 100644
--- a/src/snakeoil/test/modules.py
+++ b/src/snakeoil/test/modules.py
@@ -3,12 +3,12 @@ from . import mixins
class ExportedModules(mixins.PythonNamespaceWalker):
- target_namespace = 'snakeoil'
+ target_namespace = "snakeoil"
def test__all__accuracy(self):
failures = []
for module in self.walk_namespace(self.target_namespace):
- for target in getattr(module, '__all__', ()):
+ for target in getattr(module, "__all__", ()):
if not hasattr(module, target):
failures.append((module, target))
assert not failures, f"nonexistent __all__ targets spotted: {failures}"
diff --git a/src/snakeoil/test/slot_shadowing.py b/src/snakeoil/test/slot_shadowing.py
index 3e260c2e..fac66195 100644
--- a/src/snakeoil/test/slot_shadowing.py
+++ b/src/snakeoil/test/slot_shadowing.py
@@ -5,7 +5,7 @@ from . import mixins
class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker):
- target_namespace = 'snakeoil'
+ target_namespace = "snakeoil"
err_if_slots_is_str = True
err_if_slots_is_mutable = True
@@ -22,20 +22,20 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker):
@staticmethod
def mk_name(kls):
- return f'{kls.__module__}.{kls.__name__}'
+ return f"{kls.__module__}.{kls.__name__}"
def _should_ignore(self, kls):
return self.mk_name(kls).split(".")[0] != self.target_namespace
def run_check(self, kls):
- if getattr(kls, '__slotting_intentionally_disabled__', False):
+ if getattr(kls, "__slotting_intentionally_disabled__", False):
return
slotting = {}
raw_slottings = {}
for parent in self.recurse_parents(kls):
- slots = getattr(parent, '__slots__', None)
+ slots = getattr(parent, "__slots__", None)
if slots is None:
continue
@@ -49,14 +49,15 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker):
for slot in slots:
slotting.setdefault(slot, parent)
- slots = getattr(kls, '__slots__', None)
+ slots = getattr(kls, "__slots__", None)
if slots is None and not slotting:
return
if isinstance(slots, str):
if self.err_if_slots_is_str:
pytest.fail(
- f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)")
+ f"cls {kls!r}; slots is {slots!r} (should be a tuple or list)"
+ )
slots = (slots,)
if slots is None:
@@ -64,8 +65,7 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker):
if not isinstance(slots, tuple):
if self.err_if_slots_is_mutable:
- pytest.fail(
- f"cls {kls!r}; slots is {slots!r}- - should be a tuple")
+ pytest.fail(f"cls {kls!r}; slots is {slots!r}- - should be a tuple")
slots = tuple(slots)
if slots is None or (slots and slots in raw_slottings):
@@ -74,9 +74,11 @@ class SlotShadowing(mixins.TargetedNamespaceWalker, mixins.SubclassWalker):
# daftly copied the parents... thus defeating the purpose.
pytest.fail(
f"cls {kls!r}; slots is {slots!r}, seemingly inherited from "
- f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()")
+ f"{raw_slottings[slots]!r}; the derivative class should be __slots__ = ()"
+ )
for slot in slots:
if slot in slotting:
pytest.fail(
- f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}")
+ f"cls {kls!r}; slot {slot!r} was already defined at {slotting[slot]!r}"
+ )
diff --git a/src/snakeoil/version.py b/src/snakeoil/version.py
index fa3fc96f..2fd65d2c 100644
--- a/src/snakeoil/version.py
+++ b/src/snakeoil/version.py
@@ -28,43 +28,46 @@ def get_version(project, repo_file, api_version=None):
version_info = None
if api_version is None:
try:
- api_version = getattr(import_module(project), '__version__')
+ api_version = getattr(import_module(project), "__version__")
except ImportError:
- raise ValueError(f'no {project} module in the syspath')
+ raise ValueError(f"no {project} module in the syspath")
try:
- version_info = getattr(
- import_module(f'{project}._verinfo'), 'version_info')
+ version_info = getattr(import_module(f"{project}._verinfo"), "version_info")
except ImportError:
# we're probably in a git repo
path = os.path.dirname(os.path.abspath(repo_file))
version_info = get_git_version(path)
if version_info is None:
- s = ''
- elif version_info['tag'] == api_version:
+ s = ""
+ elif version_info["tag"] == api_version:
s = f" -- released {version_info['date']}"
else:
- rev = version_info['rev'][:7]
- date = version_info['date']
- commits = version_info.get('commits', None)
- commits = f'-{commits}' if commits is not None else ''
- s = f'{commits}-g{rev} -- {date}'
+ rev = version_info["rev"][:7]
+ date = version_info["date"]
+ commits = version_info.get("commits", None)
+ commits = f"-{commits}" if commits is not None else ""
+ s = f"{commits}-g{rev} -- {date}"
- _ver = f'{project} {api_version}{s}'
+ _ver = f"{project} {api_version}{s}"
return _ver
def _run_git(path, cmd):
env = dict(os.environ)
- for key in env.copy(): # pragma: no cover
+ for key in env.copy(): # pragma: no cover
if key.startswith("LC_"):
del env[key]
env["LC_CTYPE"] = "C"
env["LC_ALL"] = "C"
r = subprocess.Popen(
- ['git'] + list(cmd), stdout=subprocess.PIPE, env=env,
- stderr=subprocess.DEVNULL, cwd=path)
+ ["git"] + list(cmd),
+ stdout=subprocess.PIPE,
+ env=env,
+ stderr=subprocess.DEVNULL,
+ cwd=path,
+ )
stdout = r.communicate()[0]
return stdout, r.returncode
@@ -83,21 +86,20 @@ def get_git_version(path):
tag = _get_git_tag(path, data[0])
# get number of commits since most recent tag
- stdout, ret = _run_git(path, ['describe', '--tags', '--abbrev=0'])
+ stdout, ret = _run_git(path, ["describe", "--tags", "--abbrev=0"])
prev_tag = None
commits = None
if ret == 0:
prev_tag = stdout.decode().strip()
- stdout, ret = _run_git(
- path, ['log', '--oneline', f'{prev_tag}..HEAD'])
+ stdout, ret = _run_git(path, ["log", "--oneline", f"{prev_tag}..HEAD"])
if ret == 0:
commits = len(stdout.decode().splitlines())
return {
- 'rev': data[0],
- 'date': data[1],
- 'tag': tag,
- 'commits': commits,
+ "rev": data[0],
+ "date": data[1],
+ "tag": tag,
+ "commits": commits,
}
except EnvironmentError as exc:
# ENOENT is thrown when the git binary can't be found.
@@ -107,14 +109,14 @@ def get_git_version(path):
def _get_git_tag(path, rev):
- stdout, _ = _run_git(path, ['name-rev', '--tag', rev])
+ stdout, _ = _run_git(path, ["name-rev", "--tag", rev])
tag = stdout.decode().split()
if len(tag) != 2:
return None
tag = tag[1]
if not tag.startswith("tags/"):
return None
- tag = tag[len("tags/"):]
+ tag = tag[len("tags/") :]
if tag.endswith("^0"):
tag = tag[:-2]
if tag.startswith("v"):
diff --git a/src/snakeoil/weakrefs.py b/src/snakeoil/weakrefs.py
index 1fb7e0bf..83a79c00 100644
--- a/src/snakeoil/weakrefs.py
+++ b/src/snakeoil/weakrefs.py
@@ -13,7 +13,6 @@ def finalize_instance(obj, weakref_inst):
class WeakRefProxy(BaseDelayedObject):
-
def __instantiate_proxy_instance__(self):
obj = BaseDelayedObject.__instantiate_proxy_instance__(self)
weakref = ref(self, partial(finalize_instance, obj))
diff --git a/tests/cli/test_arghparse.py b/tests/cli/test_arghparse.py
index ccaa65e7..4741d861 100644
--- a/tests/cli/test_arghparse.py
+++ b/tests/cli/test_arghparse.py
@@ -11,56 +11,60 @@ from snakeoil.test import argparse_helpers
class TestArgparseDocs:
-
def test_add_argument_docs(self):
# force using an unpatched version of argparse
reload(argparse)
parser = argparse.ArgumentParser()
- parser.add_argument('--foo', action='store_true')
+ parser.add_argument("--foo", action="store_true")
# vanilla argparse doesn't support docs kwargs
with pytest.raises(TypeError):
parser.add_argument(
- '-b', '--blah', action='store_true', docs='Blah blah blah')
+ "-b", "--blah", action="store_true", docs="Blah blah blah"
+ )
with pytest.raises(TypeError):
- parser.add_argument_group('fa', description='fa la la', docs='fa la la la')
+ parser.add_argument_group("fa", description="fa la la", docs="fa la la la")
with pytest.raises(TypeError):
- parser.add_mutually_exclusive_group('fee', description='fi', docs='fo fum')
+ parser.add_mutually_exclusive_group("fee", description="fi", docs="fo fum")
# forcibly monkey-patch argparse to allow docs kwargs
reload(arghparse)
- default = 'baz baz'
- docs = 'blah blah'
+ default = "baz baz"
+ docs = "blah blah"
for enable_docs, expected_txt in ((False, default), (True, docs)):
arghparse._generate_docs = enable_docs
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(description=default, docs=docs)
- subparser = subparsers.add_parser('foo', description=default, docs=docs)
+ subparser = subparsers.add_parser("foo", description=default, docs=docs)
action = parser.add_argument(
- '-b', '--blah', action='store_true', help=default, docs=docs)
- arg_group = parser.add_argument_group('fa', description=default, docs=docs)
+ "-b", "--blah", action="store_true", help=default, docs=docs
+ )
+ arg_group = parser.add_argument_group("fa", description=default, docs=docs)
mut_arg_group = parser.add_mutually_exclusive_group()
mut_action = mut_arg_group.add_argument(
- '-f', '--fee', action='store_true', help=default, docs=docs)
+ "-f", "--fee", action="store_true", help=default, docs=docs
+ )
- assert getattr(parser._subparsers, 'description', None) == expected_txt
- assert getattr(subparser, 'description', None) == expected_txt
- assert getattr(action, 'help', None) == expected_txt
- assert getattr(arg_group, 'description', None) == expected_txt
- assert getattr(mut_action, 'help', None) == expected_txt
+ assert getattr(parser._subparsers, "description", None) == expected_txt
+ assert getattr(subparser, "description", None) == expected_txt
+ assert getattr(action, "help", None) == expected_txt
+ assert getattr(arg_group, "description", None) == expected_txt
+ assert getattr(mut_action, "help", None) == expected_txt
# list/tuple-based docs
arghparse._generate_docs = True
- docs = 'foo bar'
+ docs = "foo bar"
parser = argparse.ArgumentParser()
list_action = parser.add_argument(
- '-b', '--blah', action='store_true', help=default, docs=list(docs.split()))
+ "-b", "--blah", action="store_true", help=default, docs=list(docs.split())
+ )
tuple_action = parser.add_argument(
- '-c', '--cat', action='store_true', help=default, docs=tuple(docs.split()))
- assert getattr(list_action, 'help', None) == 'foo\nbar'
- assert getattr(tuple_action, 'help', None) == 'foo\nbar'
+ "-c", "--cat", action="store_true", help=default, docs=tuple(docs.split())
+ )
+ assert getattr(list_action, "help", None) == "foo\nbar"
+ assert getattr(tuple_action, "help", None) == "foo\nbar"
class TestOptionalsParser:
@@ -68,7 +72,9 @@ class TestOptionalsParser:
# TODO: move this to a generic argparse fixture
@pytest.fixture(autouse=True)
def __setup_optionals_parser(self):
- self.optionals_parser = argparse_helpers.mangle_parser(arghparse.OptionalsParser())
+ self.optionals_parser = argparse_helpers.mangle_parser(
+ arghparse.OptionalsParser()
+ )
def test_no_args(self):
args, unknown = self.optionals_parser.parse_known_optionals([])
@@ -76,14 +82,14 @@ class TestOptionalsParser:
assert unknown == []
def test_only_positionals(self):
- self.optionals_parser.add_argument('args')
+ self.optionals_parser.add_argument("args")
args, unknown = self.optionals_parser.parse_known_optionals([])
- assert vars(args) == {'args': None}
+ assert vars(args) == {"args": None}
assert unknown == []
def test_optionals(self):
- self.optionals_parser.add_argument('--opt1')
- self.optionals_parser.add_argument('args')
+ self.optionals_parser.add_argument("--opt1")
+ self.optionals_parser.add_argument("args")
parse = self.optionals_parser.parse_known_optionals
# no args
@@ -92,37 +98,37 @@ class TestOptionalsParser:
assert unknown == []
# only known optional
- args, unknown = parse(['--opt1', 'yes'])
- assert args.opt1 == 'yes'
+ args, unknown = parse(["--opt1", "yes"])
+ assert args.opt1 == "yes"
assert unknown == []
# unknown optional
- args, unknown = parse(['--foo'])
+ args, unknown = parse(["--foo"])
assert args.opt1 is None
- assert unknown == ['--foo']
+ assert unknown == ["--foo"]
# unknown optional and positional
- args, unknown = parse(['--foo', 'arg'])
+ args, unknown = parse(["--foo", "arg"])
assert args.opt1 is None
- assert unknown == ['--foo', 'arg']
+ assert unknown == ["--foo", "arg"]
# known optional with unknown optional
- args, unknown = parse(['--opt1', 'yes', '--foo'])
- assert args.opt1 == 'yes'
- assert unknown == ['--foo']
+ args, unknown = parse(["--opt1", "yes", "--foo"])
+ assert args.opt1 == "yes"
+ assert unknown == ["--foo"]
# different order
- args, unknown = parse(['--foo', '--opt1', 'yes'])
- assert args.opt1 == 'yes'
- assert unknown == ['--foo']
+ args, unknown = parse(["--foo", "--opt1", "yes"])
+ assert args.opt1 == "yes"
+ assert unknown == ["--foo"]
# known optional with unknown positional
- args, unknown = parse(['--opt1', 'yes', 'arg'])
- assert args.opt1 == 'yes'
- assert unknown == ['arg']
+ args, unknown = parse(["--opt1", "yes", "arg"])
+ assert args.opt1 == "yes"
+ assert unknown == ["arg"]
# known optionals parsing stops at the first positional arg
- args, unknown = parse(['arg', '--opt1', 'yes'])
+ args, unknown = parse(["arg", "--opt1", "yes"])
assert args.opt1 is None
- assert unknown == ['arg', '--opt1', 'yes']
+ assert unknown == ["arg", "--opt1", "yes"]
class TestCsvActionsParser:
@@ -134,20 +140,19 @@ class TestCsvActionsParser:
def test_bad_action(self):
with pytest.raises(ValueError) as excinfo:
- self.csv_parser.add_argument('--arg1', action='unknown')
+ self.csv_parser.add_argument("--arg1", action="unknown")
assert 'unknown action "unknown"' == str(excinfo.value)
def test_csv_actions(self):
- self.csv_parser.add_argument('--arg1', action='csv')
- self.csv_parser.add_argument('--arg2', action='csv_append')
- self.csv_parser.add_argument('--arg3', action='csv_negations')
- self.csv_parser.add_argument('--arg4', action='csv_negations_append')
- self.csv_parser.add_argument('--arg5', action='csv_elements')
- self.csv_parser.add_argument('--arg6', action='csv_elements_append')
+ self.csv_parser.add_argument("--arg1", action="csv")
+ self.csv_parser.add_argument("--arg2", action="csv_append")
+ self.csv_parser.add_argument("--arg3", action="csv_negations")
+ self.csv_parser.add_argument("--arg4", action="csv_negations_append")
+ self.csv_parser.add_argument("--arg5", action="csv_elements")
+ self.csv_parser.add_argument("--arg6", action="csv_elements_append")
class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser):
-
def test_debug(self):
# debug passed
parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser(debug=True))
@@ -161,8 +166,10 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser):
assert namespace.debug is False
# debug passed in sys.argv -- early debug attr on the parser instance is set
- with mock.patch('sys.argv', ['script', '--debug']):
- parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser(debug=True))
+ with mock.patch("sys.argv", ["script", "--debug"]):
+ parser = argparse_helpers.mangle_parser(
+ arghparse.ArgumentParser(debug=True)
+ )
assert parser.debug is True
def test_debug_disabled(self):
@@ -176,34 +183,36 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser):
# parser attribute still exists
assert parser.debug is False
# but namespace attribute doesn't
- assert not hasattr(namespace, 'debug')
+ assert not hasattr(namespace, "debug")
def test_verbosity(self):
values = (
([], 0),
- (['-q'], -1),
- (['--quiet'], -1),
- (['-v'], 1),
- (['--verbose'], 1),
- (['-q', '-v'], 0),
- (['--quiet', '--verbose'], 0),
- (['-q', '-q'], -2),
- (['-v', '-v'], 2),
+ (["-q"], -1),
+ (["--quiet"], -1),
+ (["-v"], 1),
+ (["--verbose"], 1),
+ (["-q", "-v"], 0),
+ (["--quiet", "--verbose"], 0),
+ (["-q", "-q"], -2),
+ (["-v", "-v"], 2),
)
for args, val in values:
- with mock.patch('sys.argv', ['script'] + args):
+ with mock.patch("sys.argv", ["script"] + args):
parser = argparse_helpers.mangle_parser(
- arghparse.ArgumentParser(quiet=True, verbose=True))
+ arghparse.ArgumentParser(quiet=True, verbose=True)
+ )
namespace = parser.parse_args(args)
- assert parser.verbosity == val, '{} failed'.format(args)
- assert namespace.verbosity == val, '{} failed'.format(args)
+ assert parser.verbosity == val, "{} failed".format(args)
+ assert namespace.verbosity == val, "{} failed".format(args)
def test_verbosity_disabled(self):
parser = argparse_helpers.mangle_parser(
- arghparse.ArgumentParser(quiet=False, verbose=False))
+ arghparse.ArgumentParser(quiet=False, verbose=False)
+ )
# ensure the options aren't there if disabled
- for args in ('-q', '--quiet', '-v', '--verbose'):
+ for args in ("-q", "--quiet", "-v", "--verbose"):
with pytest.raises(argparse_helpers.Error):
namespace = parser.parse_args([args])
@@ -211,17 +220,15 @@ class TestArgumentParser(TestCsvActionsParser, TestOptionalsParser):
# parser attribute still exists
assert parser.verbosity == 0
# but namespace attribute doesn't
- assert not hasattr(namespace, 'verbosity')
+ assert not hasattr(namespace, "verbosity")
class BaseArgparseOptions:
-
def setup_method(self, method):
self.parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser())
class TestStoreBoolAction(BaseArgparseOptions):
-
def setup_method(self, method):
super().setup_method(method)
self.parser.add_argument("--testing", action=arghparse.StoreBool, default=None)
@@ -229,13 +236,13 @@ class TestStoreBoolAction(BaseArgparseOptions):
def test_bool_disabled(self):
for raw_val in ("n", "no", "false"):
for allowed in (raw_val.upper(), raw_val.lower()):
- namespace = self.parser.parse_args(['--testing=' + allowed])
+ namespace = self.parser.parse_args(["--testing=" + allowed])
assert namespace.testing is False
def test_bool_enabled(self):
for raw_val in ("y", "yes", "true"):
for allowed in (raw_val.upper(), raw_val.lower()):
- namespace = self.parser.parse_args(['--testing=' + allowed])
+ namespace = self.parser.parse_args(["--testing=" + allowed])
assert namespace.testing is True
def test_bool_invalid(self):
@@ -244,249 +251,244 @@ class TestStoreBoolAction(BaseArgparseOptions):
class ParseStdinTest(BaseArgparseOptions):
-
def setup_method(self, method):
super().setup_method(method)
- self.parser.add_argument(
- "testing", nargs='+', action=arghparse.ParseStdin)
+ self.parser.add_argument("testing", nargs="+", action=arghparse.ParseStdin)
def test_none_invalid(self):
with pytest.raises(argparse_helpers.Error):
self.parser.parse_args([])
def test_non_stdin(self):
- namespace = self.parser.parse_args(['foo'])
- assert namespace.testing == ['foo']
+ namespace = self.parser.parse_args(["foo"])
+ assert namespace.testing == ["foo"]
def test_non_stdin_multiple(self):
- namespace = self.parser.parse_args(['foo', 'bar'])
- assert namespace.testing == ['foo', 'bar']
+ namespace = self.parser.parse_args(["foo", "bar"])
+ assert namespace.testing == ["foo", "bar"]
def test_stdin(self):
# stdin is an interactive tty
- with mock.patch('sys.stdin.isatty', return_value=True):
+ with mock.patch("sys.stdin.isatty", return_value=True):
with pytest.raises(argparse_helpers.Error) as excinfo:
- namespace = self.parser.parse_args(['-'])
- assert 'only valid when piping data in' in str(excinfo.value)
+ namespace = self.parser.parse_args(["-"])
+ assert "only valid when piping data in" in str(excinfo.value)
# fake piping data in
for readlines, expected in (
- ([], []),
- ([' '], []),
- (['\n'], []),
- (['\n', '\n'], []),
- (['foo'], ['foo']),
- (['foo '], ['foo']),
- (['foo\n'], ['foo']),
- (['foo', 'bar', 'baz'], ['foo', 'bar', 'baz']),
- (['\nfoo\n', ' bar ', '\nbaz'], ['\nfoo', ' bar', '\nbaz']),
+ ([], []),
+ ([" "], []),
+ (["\n"], []),
+ (["\n", "\n"], []),
+ (["foo"], ["foo"]),
+ (["foo "], ["foo"]),
+ (["foo\n"], ["foo"]),
+ (["foo", "bar", "baz"], ["foo", "bar", "baz"]),
+ (["\nfoo\n", " bar ", "\nbaz"], ["\nfoo", " bar", "\nbaz"]),
):
- with mock.patch('sys.stdin') as stdin, \
- mock.patch("builtins.open", mock.mock_open()) as mock_file:
+ with mock.patch("sys.stdin") as stdin, mock.patch(
+ "builtins.open", mock.mock_open()
+ ) as mock_file:
stdin.readlines.return_value = readlines
stdin.isatty.return_value = False
- namespace = self.parser.parse_args(['-'])
+ namespace = self.parser.parse_args(["-"])
mock_file.assert_called_once_with("/dev/tty")
assert namespace.testing == expected
class TestCommaSeparatedValuesAction(BaseArgparseOptions):
-
def setup_method(self, method):
super().setup_method(method)
self.test_values = (
- ('', []),
- (',', []),
- (',,', []),
- ('a', ['a']),
- ('a,b,-c', ['a', 'b', '-c']),
+ ("", []),
+ (",", []),
+ (",,", []),
+ ("a", ["a"]),
+ ("a,b,-c", ["a", "b", "-c"]),
)
- self.action = 'csv'
+ self.action = "csv"
self.single_expected = lambda x: x
self.multi_expected = lambda x: x
def test_parse_args(self):
- self.parser.add_argument('--testing', action=self.action)
+ self.parser.add_argument("--testing", action=self.action)
for raw_val, expected in self.test_values:
- namespace = self.parser.parse_args(['--testing=' + raw_val])
+ namespace = self.parser.parse_args(["--testing=" + raw_val])
assert namespace.testing == self.single_expected(expected)
def test_parse_multi_args(self):
- self.parser.add_argument('--testing', action=self.action)
+ self.parser.add_argument("--testing", action=self.action)
for raw_val, expected in self.test_values:
- namespace = self.parser.parse_args([
- '--testing=' + raw_val, '--testing=' + raw_val,
- ])
+ namespace = self.parser.parse_args(
+ [
+ "--testing=" + raw_val,
+ "--testing=" + raw_val,
+ ]
+ )
assert namespace.testing == self.multi_expected(expected)
class TestCommaSeparatedValuesAppendAction(TestCommaSeparatedValuesAction):
-
def setup_method(self, method):
super().setup_method(method)
- self.action = 'csv_append'
+ self.action = "csv_append"
self.multi_expected = lambda x: x + x
class TestCommaSeparatedNegationsAction(TestCommaSeparatedValuesAction):
-
def setup_method(self, method):
super().setup_method(method)
self.test_values = (
- ('', ([], [])),
- (',', ([], [])),
- (',,', ([], [])),
- ('a', ([], ['a'])),
- ('-a', (['a'], [])),
- ('a,-b,-c,d', (['b', 'c'], ['a', 'd'])),
+ ("", ([], [])),
+ (",", ([], [])),
+ (",,", ([], [])),
+ ("a", ([], ["a"])),
+ ("-a", (["a"], [])),
+ ("a,-b,-c,d", (["b", "c"], ["a", "d"])),
)
- self.bad_args = ('-',)
- self.action = 'csv_negations'
+ self.bad_args = ("-",)
+ self.action = "csv_negations"
def test_parse_bad_args(self):
- self.parser.add_argument('--testing', action=self.action)
+ self.parser.add_argument("--testing", action=self.action)
for arg in self.bad_args:
with pytest.raises(argparse.ArgumentTypeError) as excinfo:
- namespace = self.parser.parse_args(['--testing=' + arg])
- assert 'without a token' in str(excinfo.value)
+ namespace = self.parser.parse_args(["--testing=" + arg])
+ assert "without a token" in str(excinfo.value)
class TestCommaSeparatedNegationsAppendAction(TestCommaSeparatedNegationsAction):
-
def setup_method(self, method):
super().setup_method(method)
- self.action = 'csv_negations_append'
+ self.action = "csv_negations_append"
self.multi_expected = lambda x: tuple(x + y for x, y in zip(x, x))
class TestCommaSeparatedElementsAction(TestCommaSeparatedNegationsAction):
-
def setup_method(self, method):
super().setup_method(method)
self.test_values = (
- ('', ([], [], [])),
- (',', ([], [], [])),
- (',,', ([], [], [])),
- ('-a', (['a'], [], [])),
- ('a', ([], ['a'], [])),
- ('+a', ([], [], ['a'])),
- ('a,-b,-c,d', (['b', 'c'], ['a', 'd'], [])),
- ('a,-b,+c,-d,+e,f', (['b', 'd'], ['a', 'f'], ['c', 'e'])),
+ ("", ([], [], [])),
+ (",", ([], [], [])),
+ (",,", ([], [], [])),
+ ("-a", (["a"], [], [])),
+ ("a", ([], ["a"], [])),
+ ("+a", ([], [], ["a"])),
+ ("a,-b,-c,d", (["b", "c"], ["a", "d"], [])),
+ ("a,-b,+c,-d,+e,f", (["b", "d"], ["a", "f"], ["c", "e"])),
)
- self.bad_values = ('-', '+')
- self.action = 'csv_elements'
+ self.bad_values = ("-", "+")
+ self.action = "csv_elements"
class TestCommaSeparatedElementsAppendAction(TestCommaSeparatedElementsAction):
-
def setup_method(self, method):
super().setup_method(method)
- self.action = 'csv_elements_append'
+ self.action = "csv_elements_append"
self.multi_expected = lambda x: tuple(x + y for x, y in zip(x, x))
class TestExistentPathType(BaseArgparseOptions):
-
def setup_method(self, method):
super().setup_method(method)
- self.parser.add_argument('--path', type=arghparse.existent_path)
+ self.parser.add_argument("--path", type=arghparse.existent_path)
def test_nonexistent(self):
# nonexistent path arg raises an error
with pytest.raises(argparse_helpers.Error):
- self.parser.parse_args(['--path=/path/to/nowhere'])
+ self.parser.parse_args(["--path=/path/to/nowhere"])
def test_os_errors(self, tmpdir):
# random OS/FS issues raise errors
- with mock.patch('os.path.realpath') as realpath:
- realpath.side_effect = OSError(19, 'Random OS error')
+ with mock.patch("os.path.realpath") as realpath:
+ realpath.side_effect = OSError(19, "Random OS error")
with pytest.raises(argparse_helpers.Error):
- self.parser.parse_args(['--path=%s' % tmpdir])
+ self.parser.parse_args(["--path=%s" % tmpdir])
def test_regular_usage(self, tmpdir):
- namespace = self.parser.parse_args(['--path=%s' % tmpdir])
+ namespace = self.parser.parse_args(["--path=%s" % tmpdir])
assert namespace.path == str(tmpdir)
class TestExistentDirType(BaseArgparseOptions):
-
def setup_method(self, method):
super().setup_method(method)
- self.parser.add_argument('--path', type=arghparse.existent_dir)
+ self.parser.add_argument("--path", type=arghparse.existent_dir)
def test_nonexistent(self):
# nonexistent path arg raises an error
with pytest.raises(argparse_helpers.Error):
- self.parser.parse_args(['--path=/path/to/nowhere'])
+ self.parser.parse_args(["--path=/path/to/nowhere"])
def test_os_errors(self, tmp_path):
# random OS/FS issues raise errors
- with mock.patch('os.path.realpath') as realpath:
- realpath.side_effect = OSError(19, 'Random OS error')
+ with mock.patch("os.path.realpath") as realpath:
+ realpath.side_effect = OSError(19, "Random OS error")
with pytest.raises(argparse_helpers.Error):
- self.parser.parse_args([f'--path={tmp_path}'])
+ self.parser.parse_args([f"--path={tmp_path}"])
def test_file_path(self, tmp_path):
- f = tmp_path / 'file'
+ f = tmp_path / "file"
f.touch()
with pytest.raises(argparse_helpers.Error):
- self.parser.parse_args([f'--path={f}'])
+ self.parser.parse_args([f"--path={f}"])
def test_regular_usage(self, tmp_path):
- namespace = self.parser.parse_args([f'--path={tmp_path}'])
+ namespace = self.parser.parse_args([f"--path={tmp_path}"])
assert namespace.path == str(tmp_path)
class TestNamespace:
-
def setup_method(self, method):
self.parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser())
def test_pop(self):
- self.parser.set_defaults(test='test')
+ self.parser.set_defaults(test="test")
namespace = self.parser.parse_args([])
- assert namespace.pop('test') == 'test'
+ assert namespace.pop("test") == "test"
# re-popping raises an exception since the attr has been removed
with pytest.raises(AttributeError):
- namespace.pop('test')
+ namespace.pop("test")
# popping a nonexistent attr with a fallback returns the fallback
- assert namespace.pop('nonexistent', 'foo') == 'foo'
+ assert namespace.pop("nonexistent", "foo") == "foo"
def test_collapse_delayed(self):
def _delayed_val(namespace, attr, val):
setattr(namespace, attr, val)
- self.parser.set_defaults(delayed=arghparse.DelayedValue(partial(_delayed_val, val=42)))
+
+ self.parser.set_defaults(
+ delayed=arghparse.DelayedValue(partial(_delayed_val, val=42))
+ )
namespace = self.parser.parse_args([])
assert namespace.delayed == 42
def test_bool(self):
namespace = arghparse.Namespace()
assert not namespace
- namespace.arg = 'foo'
+ namespace.arg = "foo"
assert namespace
class TestManHelpAction:
-
def test_help(self, capsys):
parser = argparse_helpers.mangle_parser(arghparse.ArgumentParser())
- with mock.patch('subprocess.Popen') as popen:
+ with mock.patch("subprocess.Popen") as popen:
# --help long option tries man page first before falling back to help output
with pytest.raises(argparse_helpers.Exit):
- namespace = parser.parse_args(['--help'])
+ namespace = parser.parse_args(["--help"])
popen.assert_called_once()
- assert popen.call_args[0][0][0] == 'man'
+ assert popen.call_args[0][0][0] == "man"
captured = capsys.readouterr()
- assert captured.out.strip().startswith('usage: ')
+ assert captured.out.strip().startswith("usage: ")
popen.reset_mock()
# -h short option just displays the regular help output
with pytest.raises(argparse_helpers.Exit):
- namespace = parser.parse_args(['-h'])
+ namespace = parser.parse_args(["-h"])
popen.assert_not_called()
captured = capsys.readouterr()
- assert captured.out.strip().startswith('usage: ')
+ assert captured.out.strip().startswith("usage: ")
popen.reset_mock()
diff --git a/tests/cli/test_input.py b/tests/cli/test_input.py
index 8efb1f59..2f15f256 100644
--- a/tests/cli/test_input.py
+++ b/tests/cli/test_input.py
@@ -9,12 +9,11 @@ from snakeoil.test.argparse_helpers import FakeStreamFormatter
@pytest.fixture
def mocked_input():
- with mock.patch('builtins.input') as mocked_input:
+ with mock.patch("builtins.input") as mocked_input:
yield mocked_input
class TestUserQuery:
-
@pytest.fixture(autouse=True)
def __setup(self):
self.out = FakeStreamFormatter()
@@ -22,98 +21,104 @@ class TestUserQuery:
self.query = partial(userquery, out=self.out, err=self.err)
def test_default_answer(self, mocked_input):
- mocked_input.return_value = ''
- assert self.query('foo') == True
+ mocked_input.return_value = ""
+ assert self.query("foo") == True
def test_tuple_prompt(self, mocked_input):
- mocked_input.return_value = ''
- prompt = 'perhaps a tuple'
+ mocked_input.return_value = ""
+ prompt = "perhaps a tuple"
assert self.query(tuple(prompt.split())) == True
- output = ''.join(prompt.split())
- assert self.out.get_text_stream().strip().split('\n')[0][:len(output)] == output
+ output = "".join(prompt.split())
+ assert (
+ self.out.get_text_stream().strip().split("\n")[0][: len(output)] == output
+ )
def test_no_default_answer(self, mocked_input):
responses = {
- 'a': ('z', 'Yes'),
- 'b': ('y', 'No'),
+ "a": ("z", "Yes"),
+ "b": ("y", "No"),
}
# no default answer returns None for empty input
- mocked_input.return_value = ''
- assert self.query('foo', responses=responses) == None
- mocked_input.return_value = 'a'
- assert self.query('foo', responses=responses) == 'z'
- mocked_input.return_value = 'b'
- assert self.query('foo', responses=responses) == 'y'
+ mocked_input.return_value = ""
+ assert self.query("foo", responses=responses) == None
+ mocked_input.return_value = "a"
+ assert self.query("foo", responses=responses) == "z"
+ mocked_input.return_value = "b"
+ assert self.query("foo", responses=responses) == "y"
def test_ambiguous_input(self, mocked_input):
responses = {
- 'a': ('z', 'Yes'),
- 'A': ('y', 'No'),
+ "a": ("z", "Yes"),
+ "A": ("y", "No"),
}
- mocked_input.return_value = 'a'
+ mocked_input.return_value = "a"
with pytest.raises(NoChoice):
- self.query('foo', responses=responses)
- error_output = self.err.get_text_stream().strip().split('\n')[1]
- expected = 'Response %r is ambiguous (%s)' % (
- mocked_input.return_value, ', '.join(sorted(responses.keys())))
+ self.query("foo", responses=responses)
+ error_output = self.err.get_text_stream().strip().split("\n")[1]
+ expected = "Response %r is ambiguous (%s)" % (
+ mocked_input.return_value,
+ ", ".join(sorted(responses.keys())),
+ )
assert error_output == expected
def test_default_correct_input(self, mocked_input):
- for input, output in (('no', False),
- ('No', False),
- ('yes', True),
- ('Yes', True)):
+ for input, output in (
+ ("no", False),
+ ("No", False),
+ ("yes", True),
+ ("Yes", True),
+ ):
mocked_input.return_value = input
- assert self.query('foo') == output
+ assert self.query("foo") == output
def test_default_answer_no_matches(self, mocked_input):
- mocked_input.return_value = ''
+ mocked_input.return_value = ""
with pytest.raises(ValueError):
- self.query('foo', default_answer='foo')
+ self.query("foo", default_answer="foo")
assert self.out.stream == []
def test_custom_default_answer(self, mocked_input):
- mocked_input.return_value = ''
- assert self.query('foo', default_answer=False) == False
+ mocked_input.return_value = ""
+ assert self.query("foo", default_answer=False) == False
def test_eof_nochoice(self, mocked_input):
# user hits ctrl-d
mocked_input.side_effect = EOFError
with pytest.raises(NoChoice):
- self.query('foo')
- output = self.out.get_text_stream().strip().split('\n')[1]
- expected = 'Not answerable: EOF on STDIN'
+ self.query("foo")
+ output = self.out.get_text_stream().strip().split("\n")[1]
+ expected = "Not answerable: EOF on STDIN"
assert output == expected
def test_stdin_closed_nochoice(self, mocked_input):
- mocked_input.side_effect = IOError(errno.EBADF, '')
+ mocked_input.side_effect = IOError(errno.EBADF, "")
with pytest.raises(NoChoice):
- self.query('foo')
- output = self.out.get_text_stream().strip().split('\n')[1]
- expected = 'Not answerable: STDIN is either closed, or not readable'
+ self.query("foo")
+ output = self.out.get_text_stream().strip().split("\n")[1]
+ expected = "Not answerable: STDIN is either closed, or not readable"
assert output == expected
def test_unhandled_ioerror(self, mocked_input):
- mocked_input.side_effect = IOError(errno.ENODEV, '')
+ mocked_input.side_effect = IOError(errno.ENODEV, "")
with pytest.raises(IOError):
- self.query('foo')
+ self.query("foo")
def test_bad_choice_limit(self, mocked_input):
# user hits enters a bad choice 3 times in a row
- mocked_input.return_value = 'bad'
+ mocked_input.return_value = "bad"
with pytest.raises(NoChoice):
- self.query('foo')
+ self.query("foo")
assert mocked_input.call_count == 3
- output = self.err.get_text_stream().strip().split('\n')[1]
+ output = self.err.get_text_stream().strip().split("\n")[1]
expected = "Sorry, response %r not understood." % (mocked_input.return_value,)
assert output == expected
def test_custom_choice_limit(self, mocked_input):
# user hits enters a bad choice 5 times in a row
- mocked_input.return_value = 'haha'
+ mocked_input.return_value = "haha"
with pytest.raises(NoChoice):
- self.query('foo', limit=5)
+ self.query("foo", limit=5)
assert mocked_input.call_count == 5
- output = self.err.get_text_stream().strip().split('\n')[1]
+ output = self.err.get_text_stream().strip().split("\n")[1]
expected = "Sorry, response %r not understood." % (mocked_input.return_value,)
assert output == expected
diff --git a/tests/compression/__init__.py b/tests/compression/__init__.py
index 3b70dcba..0bf26d04 100644
--- a/tests/compression/__init__.py
+++ b/tests/compression/__init__.py
@@ -4,78 +4,100 @@ import pytest
from snakeoil import compression
from snakeoil.process import CommandNotFound, find_binary
+
def hide_binary(*binaries: str):
def mock_find_binary(name):
if name in binaries:
raise CommandNotFound(name)
return find_binary(name)
- return patch('snakeoil.process.find_binary', side_effect=mock_find_binary)
+ return patch("snakeoil.process.find_binary", side_effect=mock_find_binary)
class Base:
- module: str = ''
- decompressed_test_data: bytes = b''
- compressed_test_data: bytes = b''
+ module: str = ""
+ decompressed_test_data: bytes = b""
+ compressed_test_data: bytes = b""
def decompress(self, data: bytes) -> bytes:
- raise NotImplementedError(self, 'decompress')
+ raise NotImplementedError(self, "decompress")
- @pytest.mark.parametrize('parallelize', (True, False))
- @pytest.mark.parametrize('level', (1, 9))
+ @pytest.mark.parametrize("parallelize", (True, False))
+ @pytest.mark.parametrize("level", (1, 9))
def test_compress_data(self, level, parallelize):
- compressed = compression.compress_data(self.module, self.decompressed_test_data, level=level, parallelize=parallelize)
+ compressed = compression.compress_data(
+ self.module,
+ self.decompressed_test_data,
+ level=level,
+ parallelize=parallelize,
+ )
assert compressed
assert self.decompress(compressed) == self.decompressed_test_data
- @pytest.mark.parametrize('parallelize', (True, False))
+ @pytest.mark.parametrize("parallelize", (True, False))
def test_decompress_data(self, parallelize):
- assert self.decompressed_test_data == compression.decompress_data(self.module, self.compressed_test_data, parallelize=parallelize)
+ assert self.decompressed_test_data == compression.decompress_data(
+ self.module, self.compressed_test_data, parallelize=parallelize
+ )
- @pytest.mark.parametrize('parallelize', (True, False))
- @pytest.mark.parametrize('level', (1, 9))
+ @pytest.mark.parametrize("parallelize", (True, False))
+ @pytest.mark.parametrize("level", (1, 9))
def test_compress_handle(self, tmp_path, level, parallelize):
- path = tmp_path / f'test.{self.module}'
+ path = tmp_path / f"test.{self.module}"
- stream = compression.compress_handle(self.module, str(path), level=level, parallelize=parallelize)
+ stream = compression.compress_handle(
+ self.module, str(path), level=level, parallelize=parallelize
+ )
stream.write(self.decompressed_test_data)
stream.close()
assert self.decompress(path.read_bytes()) == self.decompressed_test_data
with path.open("wb") as file:
- stream = compression.compress_handle(self.module, file, level=level, parallelize=parallelize)
+ stream = compression.compress_handle(
+ self.module, file, level=level, parallelize=parallelize
+ )
stream.write(self.decompressed_test_data)
stream.close()
assert self.decompress(path.read_bytes()) == self.decompressed_test_data
with path.open("wb") as file:
- stream = compression.compress_handle(self.module, file.fileno(), level=level, parallelize=parallelize)
+ stream = compression.compress_handle(
+ self.module, file.fileno(), level=level, parallelize=parallelize
+ )
stream.write(self.decompressed_test_data)
stream.close()
assert self.decompress(path.read_bytes()) == self.decompressed_test_data
with pytest.raises(TypeError):
- compression.compress_handle(self.module, b'', level=level, parallelize=parallelize)
+ compression.compress_handle(
+ self.module, b"", level=level, parallelize=parallelize
+ )
- @pytest.mark.parametrize('parallelize', (True, False))
+ @pytest.mark.parametrize("parallelize", (True, False))
def test_decompress_handle(self, tmp_path, parallelize):
- path = tmp_path / f'test.{self.module}'
+ path = tmp_path / f"test.{self.module}"
path.write_bytes(self.compressed_test_data)
- stream = compression.decompress_handle(self.module, str(path), parallelize=parallelize)
+ stream = compression.decompress_handle(
+ self.module, str(path), parallelize=parallelize
+ )
assert stream.read() == self.decompressed_test_data
stream.close()
with path.open("rb") as file:
- stream = compression.decompress_handle(self.module, file, parallelize=parallelize)
+ stream = compression.decompress_handle(
+ self.module, file, parallelize=parallelize
+ )
assert stream.read() == self.decompressed_test_data
stream.close()
with path.open("rb") as file:
- stream = compression.decompress_handle(self.module, file.fileno(), parallelize=parallelize)
+ stream = compression.decompress_handle(
+ self.module, file.fileno(), parallelize=parallelize
+ )
assert stream.read() == self.decompressed_test_data
stream.close()
with pytest.raises(TypeError):
- compression.decompress_handle(self.module, b'', parallelize=parallelize)
+ compression.decompress_handle(self.module, b"", parallelize=parallelize)
diff --git a/tests/compression/test_bzip2.py b/tests/compression/test_bzip2.py
index f3093d09..9fdffd9a 100644
--- a/tests/compression/test_bzip2.py
+++ b/tests/compression/test_bzip2.py
@@ -10,28 +10,29 @@ from . import Base, hide_binary
def test_no_native():
- with hide_imports('bz2'):
+ with hide_imports("bz2"):
importlib.reload(_bzip2)
assert not _bzip2.native
def test_missing_bzip2_binary():
- with hide_binary('bzip2'):
- with pytest.raises(CommandNotFound, match='bzip2'):
+ with hide_binary("bzip2"):
+ with pytest.raises(CommandNotFound, match="bzip2"):
importlib.reload(_bzip2)
def test_missing_lbzip2_binary():
- with hide_binary('lbzip2'):
+ with hide_binary("lbzip2"):
importlib.reload(_bzip2)
assert not _bzip2.parallelizable
+
class Bzip2Base(Base):
- module = 'bzip2'
- decompressed_test_data = b'Some text here\n'
+ module = "bzip2"
+ decompressed_test_data = b"Some text here\n"
compressed_test_data = (
- b'BZh91AY&SY\x1bM\x00\x02\x00\x00\x01\xd3\x80\x00\x10@\x00\x08\x00\x02'
+ b"BZh91AY&SY\x1bM\x00\x02\x00\x00\x01\xd3\x80\x00\x10@\x00\x08\x00\x02"
b'B\x94@ \x00"\r\x03\xd4\x0c \t!\x1b\xb7\x80u/\x17rE8P\x90\x1bM\x00\x02'
)
@@ -40,37 +41,36 @@ class Bzip2Base(Base):
class TestStdlib(Bzip2Base):
-
- @pytest.fixture(autouse=True, scope='class')
+ @pytest.fixture(autouse=True, scope="class")
def _setup(self):
try:
- find_binary('bzip2')
+ find_binary("bzip2")
except CommandNotFound:
- pytest.skip('bzip2 binary not found')
- with hide_binary('lbzip2'):
+ pytest.skip("bzip2 binary not found")
+ with hide_binary("lbzip2"):
importlib.reload(_bzip2)
yield
class TestBzip2(Bzip2Base):
-
- @pytest.fixture(autouse=True, scope='class')
+ @pytest.fixture(autouse=True, scope="class")
def _setup(self):
- with hide_binary('lbzip2'):
+ with hide_binary("lbzip2"):
importlib.reload(_bzip2)
yield
class TestLbzip2(Bzip2Base):
-
- @pytest.fixture(autouse=True, scope='class')
+ @pytest.fixture(autouse=True, scope="class")
def _setup(self):
try:
- find_binary('lbzip2')
+ find_binary("lbzip2")
except CommandNotFound:
- pytest.skip('lbzip2 binary not found')
+ pytest.skip("lbzip2 binary not found")
importlib.reload(_bzip2)
def test_bad_level(self):
with pytest.raises(ValueError, match='unknown option "-0"'):
- _bzip2.compress_data(self.decompressed_test_data, level=90, parallelize=True)
+ _bzip2.compress_data(
+ self.decompressed_test_data, level=90, parallelize=True
+ )
diff --git a/tests/compression/test_init.py b/tests/compression/test_init.py
index f3a40270..f1fe5bda 100644
--- a/tests/compression/test_init.py
+++ b/tests/compression/test_init.py
@@ -11,78 +11,77 @@ from . import hide_binary
@pytest.mark.skipif(sys.platform == "darwin", reason="darwin fails with bzip2")
class TestArComp:
-
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def tar_file(self, tmp_path_factory):
data = tmp_path_factory.mktemp("data")
- (data / 'file1').write_text('Hello world')
- (data / 'file2').write_text('Larry the Cow')
- path = data / 'test 1.tar'
- subprocess.run(['tar', 'cf', str(path), 'file1', 'file2'], cwd=data, check=True)
- (data / 'file1').unlink()
- (data / 'file2').unlink()
+ (data / "file1").write_text("Hello world")
+ (data / "file2").write_text("Larry the Cow")
+ path = data / "test 1.tar"
+ subprocess.run(["tar", "cf", str(path), "file1", "file2"], cwd=data, check=True)
+ (data / "file1").unlink()
+ (data / "file2").unlink()
return str(path)
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def tar_bz2_file(self, tar_file):
- subprocess.run(['bzip2', '-z', '-k', tar_file], check=True)
+ subprocess.run(["bzip2", "-z", "-k", tar_file], check=True)
return tar_file + ".bz2"
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def tbz2_file(self, tar_bz2_file):
- new_path = tar_bz2_file.replace('.tar.bz2', '.tbz2')
+ new_path = tar_bz2_file.replace(".tar.bz2", ".tbz2")
shutil.copyfile(tar_bz2_file, new_path)
return new_path
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def lzma_file(self, tmp_path_factory):
- data = (tmp_path_factory.mktemp("data") / 'test 2.lzma')
- with data.open('wb') as f:
- subprocess.run(['lzma'], check=True, input=b'Hello world', stdout=f)
+ data = tmp_path_factory.mktemp("data") / "test 2.lzma"
+ with data.open("wb") as f:
+ subprocess.run(["lzma"], check=True, input=b"Hello world", stdout=f)
return str(data)
def test_unknown_extenstion(self, tmp_path):
- file = tmp_path / 'test.file'
- with pytest.raises(ArCompError, match='unknown compression file extension'):
- ArComp(file, ext='.foo')
+ file = tmp_path / "test.file"
+ with pytest.raises(ArCompError, match="unknown compression file extension"):
+ ArComp(file, ext=".foo")
def test_missing_tar(self, tmp_path, tar_file):
- with hide_binary('tar'), chdir(tmp_path):
- with pytest.raises(ArCompError, match='required binary not found'):
- ArComp(tar_file, ext='.tar').unpack(dest=tmp_path)
+ with hide_binary("tar"), chdir(tmp_path):
+ with pytest.raises(ArCompError, match="required binary not found"):
+ ArComp(tar_file, ext=".tar").unpack(dest=tmp_path)
def test_tar(self, tmp_path, tar_file):
with chdir(tmp_path):
- ArComp(tar_file, ext='.tar').unpack(dest=tmp_path)
- assert (tmp_path / 'file1').read_text() == 'Hello world'
- assert (tmp_path / 'file2').read_text() == 'Larry the Cow'
+ ArComp(tar_file, ext=".tar").unpack(dest=tmp_path)
+ assert (tmp_path / "file1").read_text() == "Hello world"
+ assert (tmp_path / "file2").read_text() == "Larry the Cow"
def test_tar_bz2(self, tmp_path, tar_bz2_file):
with chdir(tmp_path):
- ArComp(tar_bz2_file, ext='.tar.bz2').unpack(dest=tmp_path)
- assert (tmp_path / 'file1').read_text() == 'Hello world'
- assert (tmp_path / 'file2').read_text() == 'Larry the Cow'
+ ArComp(tar_bz2_file, ext=".tar.bz2").unpack(dest=tmp_path)
+ assert (tmp_path / "file1").read_text() == "Hello world"
+ assert (tmp_path / "file2").read_text() == "Larry the Cow"
def test_tbz2(self, tmp_path, tbz2_file):
with chdir(tmp_path):
- ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path)
- assert (tmp_path / 'file1').read_text() == 'Hello world'
- assert (tmp_path / 'file2').read_text() == 'Larry the Cow'
+ ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path)
+ assert (tmp_path / "file1").read_text() == "Hello world"
+ assert (tmp_path / "file2").read_text() == "Larry the Cow"
def test_fallback_tbz2(self, tmp_path, tbz2_file):
with hide_binary(*next(zip(*_TarBZ2.compress_binary[:-1]))):
with chdir(tmp_path):
- ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path)
- assert (tmp_path / 'file1').read_text() == 'Hello world'
- assert (tmp_path / 'file2').read_text() == 'Larry the Cow'
+ ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path)
+ assert (tmp_path / "file1").read_text() == "Hello world"
+ assert (tmp_path / "file2").read_text() == "Larry the Cow"
def test_no_fallback_tbz2(self, tmp_path, tbz2_file):
with hide_binary(*next(zip(*_TarBZ2.compress_binary))), chdir(tmp_path):
- with pytest.raises(ArCompError, match='no compression binary'):
- ArComp(tbz2_file, ext='.tbz2').unpack(dest=tmp_path)
+ with pytest.raises(ArCompError, match="no compression binary"):
+ ArComp(tbz2_file, ext=".tbz2").unpack(dest=tmp_path)
def test_lzma(self, tmp_path, lzma_file):
- dest = tmp_path / 'file'
+ dest = tmp_path / "file"
with chdir(tmp_path):
- ArComp(lzma_file, ext='.lzma').unpack(dest=dest)
- assert (dest).read_bytes() == b'Hello world'
+ ArComp(lzma_file, ext=".lzma").unpack(dest=dest)
+ assert (dest).read_bytes() == b"Hello world"
diff --git a/tests/compression/test_xz.py b/tests/compression/test_xz.py
index f8417b30..0af7c645 100644
--- a/tests/compression/test_xz.py
+++ b/tests/compression/test_xz.py
@@ -10,26 +10,26 @@ from . import Base, hide_binary
def test_no_native():
- with hide_imports('lzma'):
+ with hide_imports("lzma"):
importlib.reload(_xz)
assert not _xz.native
def test_missing_xz_binary():
- with hide_binary('xz'):
- with pytest.raises(CommandNotFound, match='xz'):
+ with hide_binary("xz"):
+ with pytest.raises(CommandNotFound, match="xz"):
importlib.reload(_xz)
class XzBase(Base):
- module = 'xz'
- decompressed_test_data = b'Some text here\n' * 2
+ module = "xz"
+ decompressed_test_data = b"Some text here\n" * 2
compressed_test_data = (
- b'\xfd7zXZ\x00\x00\x04\xe6\xd6\xb4F\x04\xc0\x1e\x1e!\x01\x16\x00\x00\x00'
- b'\x00\x00\x00\x00\x00\x00j\xf6\x947\xe0\x00\x1d\x00\x16]\x00)\x9b\xc9\xa6g'
- b'Bw\x8c\xb3\x9eA\x9a\xbeT\xc9\xfa\xe3\x19\x8f(\x00\x00\x00\x00\x00\x96N'
- b'\xa8\x8ed\xa2WH\x00\x01:\x1e1V \xff\x1f\xb6\xf3}\x01\x00\x00\x00\x00\x04YZ'
+ b"\xfd7zXZ\x00\x00\x04\xe6\xd6\xb4F\x04\xc0\x1e\x1e!\x01\x16\x00\x00\x00"
+ b"\x00\x00\x00\x00\x00\x00j\xf6\x947\xe0\x00\x1d\x00\x16]\x00)\x9b\xc9\xa6g"
+ b"Bw\x8c\xb3\x9eA\x9a\xbeT\xc9\xfa\xe3\x19\x8f(\x00\x00\x00\x00\x00\x96N"
+ b"\xa8\x8ed\xa2WH\x00\x01:\x1e1V \xff\x1f\xb6\xf3}\x01\x00\x00\x00\x00\x04YZ"
)
def decompress(self, data: bytes) -> bytes:
@@ -37,20 +37,18 @@ class XzBase(Base):
class TestStdlib(XzBase):
-
- @pytest.fixture(autouse=True, scope='class')
+ @pytest.fixture(autouse=True, scope="class")
def _setup(self):
try:
- find_binary('xz')
+ find_binary("xz")
except CommandNotFound:
- pytest.skip('xz binary not found')
+ pytest.skip("xz binary not found")
importlib.reload(_xz)
class TestXz(XzBase):
-
- @pytest.fixture(autouse=True, scope='class')
+ @pytest.fixture(autouse=True, scope="class")
def _setup(self):
- with hide_imports('lzma'):
+ with hide_imports("lzma"):
importlib.reload(_xz)
yield
diff --git a/tests/test_bash.py b/tests/test_bash.py
index ec9df537..3d2157b8 100644
--- a/tests/test_bash.py
+++ b/tests/test_bash.py
@@ -1,128 +1,138 @@
from io import StringIO
import pytest
-from snakeoil.bash import (BashParseError, iter_read_bash, read_bash,
- read_bash_dict, read_dict)
+from snakeoil.bash import (
+ BashParseError,
+ iter_read_bash,
+ read_bash,
+ read_bash_dict,
+ read_dict,
+)
class TestBashCommentStripping:
-
def test_iter_read_bash(self):
- output = iter_read_bash(StringIO(
- '\n'
- '# hi I am a comment\n'
- 'I am not \n'
- ' asdf # inline comment\n'))
- assert list(output) == ['I am not', 'asdf']
+ output = iter_read_bash(
+ StringIO(
+ "\n" "# hi I am a comment\n" "I am not \n" " asdf # inline comment\n"
+ )
+ )
+ assert list(output) == ["I am not", "asdf"]
- output = iter_read_bash(StringIO(
- 'inline # comment '), allow_inline_comments=False)
- assert list(output) == ['inline # comment']
+ output = iter_read_bash(
+ StringIO("inline # comment "), allow_inline_comments=False
+ )
+ assert list(output) == ["inline # comment"]
def test_iter_read_bash_line_cont(self):
- output = iter_read_bash(StringIO(
- '\n'
- '# hi I am a comment\\\n'
- 'I am not \\\n'
- 'a comment \n'
- ' asdf # inline comment\\\n'),
- allow_line_cont=True)
- assert list(output) == ['I am not a comment', 'asdf']
+ output = iter_read_bash(
+ StringIO(
+ "\n"
+ "# hi I am a comment\\\n"
+ "I am not \\\n"
+ "a comment \n"
+ " asdf # inline comment\\\n"
+ ),
+ allow_line_cont=True,
+ )
+ assert list(output) == ["I am not a comment", "asdf"]
# continuation into inline comment
- output = iter_read_bash(StringIO(
- '\n'
- '# hi I am a comment\n'
- 'I am \\\n'
- 'not a \\\n'
- 'comment # inline comment\n'),
- allow_line_cont=True)
- assert list(output) == ['I am not a comment']
+ output = iter_read_bash(
+ StringIO(
+ "\n"
+ "# hi I am a comment\n"
+ "I am \\\n"
+ "not a \\\n"
+ "comment # inline comment\n"
+ ),
+ allow_line_cont=True,
+ )
+ assert list(output) == ["I am not a comment"]
# ends with continuation
- output = iter_read_bash(StringIO(
- '\n'
- '# hi I am a comment\n'
- 'I am \\\n'
- '\\\n'
- 'not a \\\n'
- 'comment\\\n'
- '\\\n'),
- allow_line_cont=True)
- assert list(output) == ['I am not a comment']
+ output = iter_read_bash(
+ StringIO(
+ "\n"
+ "# hi I am a comment\n"
+ "I am \\\n"
+ "\\\n"
+ "not a \\\n"
+ "comment\\\n"
+ "\\\n"
+ ),
+ allow_line_cont=True,
+ )
+ assert list(output) == ["I am not a comment"]
# embedded comment prefix via continued lines
- output = iter_read_bash(StringIO(
- '\\\n'
- '# comment\\\n'
- ' not a comment\n'
- '\\\n'
- ' # inner comment\n'
- 'also not\\\n'
- '#\\\n'
- 'a comment\n'),
- allow_line_cont=True)
- assert list(output) == ['not a comment', 'also not#a comment']
+ output = iter_read_bash(
+ StringIO(
+ "\\\n"
+ "# comment\\\n"
+ " not a comment\n"
+ "\\\n"
+ " # inner comment\n"
+ "also not\\\n"
+ "#\\\n"
+ "a comment\n"
+ ),
+ allow_line_cont=True,
+ )
+ assert list(output) == ["not a comment", "also not#a comment"]
# Line continuations have to end with \<newline> without any backslash
# before the pattern.
- output = iter_read_bash(StringIO(
- 'I am \\ \n'
- 'not a comment'),
- allow_line_cont=True)
- assert list(output) == ['I am \\', 'not a comment']
- output = iter_read_bash(StringIO(
- '\\\n'
- 'I am \\\\\n'
- 'not a comment'),
- allow_line_cont=True)
- assert list(output) == ['I am \\\\', 'not a comment']
+ output = iter_read_bash(
+ StringIO("I am \\ \n" "not a comment"), allow_line_cont=True
+ )
+ assert list(output) == ["I am \\", "not a comment"]
+ output = iter_read_bash(
+ StringIO("\\\n" "I am \\\\\n" "not a comment"), allow_line_cont=True
+ )
+ assert list(output) == ["I am \\\\", "not a comment"]
def test_read_bash(self):
- output = read_bash(StringIO(
- '\n'
- '# hi I am a comment\n'
- 'I am not\n'))
- assert output == ['I am not']
+ output = read_bash(StringIO("\n" "# hi I am a comment\n" "I am not\n"))
+ assert output == ["I am not"]
class TestReadDictConfig:
-
def test_read_dict(self):
- bash_dict = read_dict(StringIO(
- '\n'
- '# hi I am a comment\n'
- 'foo1=bar\n'
- 'foo2="bar"\n'
- 'foo3=\'bar"\n'))
+ bash_dict = read_dict(
+ StringIO(
+ "\n" "# hi I am a comment\n" "foo1=bar\n" 'foo2="bar"\n' "foo3='bar\"\n"
+ )
+ )
assert bash_dict == {
- 'foo1': 'bar',
- 'foo2': 'bar',
- 'foo3': '\'bar"',
- }
- assert read_dict(['foo=bar'], source_isiter=True) == {'foo': 'bar'}
+ "foo1": "bar",
+ "foo2": "bar",
+ "foo3": "'bar\"",
+ }
+ assert read_dict(["foo=bar"], source_isiter=True) == {"foo": "bar"}
with pytest.raises(BashParseError):
- read_dict(['invalid'], source_isiter=True)
+ read_dict(["invalid"], source_isiter=True)
- bash_dict = read_dict(StringIO("foo bar\nfoo2 bar\nfoo3\tbar\n"), splitter=None)
- assert bash_dict == dict.fromkeys(('foo', 'foo2', 'foo3'), 'bar')
- bash_dict = read_dict(['foo = blah', 'foo2= blah ', 'foo3=blah'], strip=True)
- assert bash_dict == dict.fromkeys(('foo', 'foo2', 'foo3'), 'blah')
+ bash_dict = read_dict(
+ StringIO("foo bar\nfoo2 bar\nfoo3\tbar\n"), splitter=None
+ )
+ assert bash_dict == dict.fromkeys(("foo", "foo2", "foo3"), "bar")
+ bash_dict = read_dict(["foo = blah", "foo2= blah ", "foo3=blah"], strip=True)
+ assert bash_dict == dict.fromkeys(("foo", "foo2", "foo3"), "blah")
class TestReadBashDict:
-
@pytest.fixture(autouse=True)
def _setup(self, tmp_path):
self.valid_file = tmp_path / "valid"
self.valid_file.write_text(
- '# hi I am a comment\n'
- 'foo1=bar\n'
+ "# hi I am a comment\n"
+ "foo1=bar\n"
"foo2='bar'\n"
'foo3="bar"\n'
- 'foo4=-/:j4\n'
- 'foo5=\n'
+ "foo4=-/:j4\n"
+ "foo5=\n"
'export foo6="bar"\n'
)
self.sourcing_file = tmp_path / "sourcing"
@@ -131,18 +141,13 @@ class TestReadBashDict:
self.sourcing_file2.write_text(f'source "{self.valid_file}"\n')
self.advanced_file = tmp_path / "advanced"
self.advanced_file.write_text(
- 'one1=1\n'
- 'one_=$one1\n'
- 'two1=2\n'
- 'two_=${two1}\n'
+ "one1=1\n" "one_=$one1\n" "two1=2\n" "two_=${two1}\n"
)
self.env_file = tmp_path / "env"
- self.env_file.write_text('imported=${external}\n')
+ self.env_file.write_text("imported=${external}\n")
self.escaped_file = tmp_path / "escaped"
self.escaped_file.write_text(
- 'end=bye\n'
- 'quoteddollar="\\${dollar}"\n'
- 'quotedexpansion="\\${${end}}"\n'
+ "end=bye\n" 'quoteddollar="\\${dollar}"\n' 'quotedexpansion="\\${${end}}"\n'
)
self.unclosed_file = tmp_path / "unclosed"
self.unclosed_file.write_text('foo="bar')
@@ -151,19 +156,19 @@ class TestReadBashDict:
try:
return read_bash_dict(handle, *args, **kwds)
finally:
- if hasattr(handle, 'close'):
+ if hasattr(handle, "close"):
handle.close()
def test_read_bash_dict(self):
# TODO this is not even close to complete
bash_dict = self.invoke_and_close(str(self.valid_file))
d = {
- 'foo1': 'bar',
- 'foo2': 'bar',
- 'foo3': 'bar',
- 'foo4': '-/:j4',
- 'foo5': '',
- 'foo6': 'bar',
+ "foo1": "bar",
+ "foo2": "bar",
+ "foo3": "bar",
+ "foo4": "-/:j4",
+ "foo5": "",
+ "foo6": "bar",
}
assert bash_dict == d
@@ -171,59 +176,81 @@ class TestReadBashDict:
self.invoke_and_close(StringIO("a=b\ny='"))
def test_var_read(self):
- assert self.invoke_and_close(StringIO("x=y@a\n")) == {'x': 'y@a'}
- assert self.invoke_and_close(StringIO("x=y~a\n")) == {'x': 'y~a'}
- assert self.invoke_and_close(StringIO("x=y^a\n")) == {'x': 'y^a'}
- assert self.invoke_and_close(StringIO('x="\nasdf\nfdsa"')) == {'x': '\nasdf\nfdsa'}
+ assert self.invoke_and_close(StringIO("x=y@a\n")) == {"x": "y@a"}
+ assert self.invoke_and_close(StringIO("x=y~a\n")) == {"x": "y~a"}
+ assert self.invoke_and_close(StringIO("x=y^a\n")) == {"x": "y^a"}
+ assert self.invoke_and_close(StringIO('x="\nasdf\nfdsa"')) == {
+ "x": "\nasdf\nfdsa"
+ }
def test_empty_assign(self):
self.valid_file.write_text("foo=\ndar=blah\n")
- assert self.invoke_and_close(str(self.valid_file)) == {'foo': '', 'dar': 'blah'}
+ assert self.invoke_and_close(str(self.valid_file)) == {"foo": "", "dar": "blah"}
self.valid_file.write_text("foo=\ndar=\n")
- assert self.invoke_and_close(str(self.valid_file)) == {'foo': '', 'dar': ''}
+ assert self.invoke_and_close(str(self.valid_file)) == {"foo": "", "dar": ""}
self.valid_file.write_text("foo=blah\ndar=\n")
- assert self.invoke_and_close(str(self.valid_file)) == {'foo': 'blah', 'dar': ''}
+ assert self.invoke_and_close(str(self.valid_file)) == {"foo": "blah", "dar": ""}
def test_quoting(self):
- assert self.invoke_and_close(StringIO("x='y \\\na'")) == {'x': 'y \\\na'}
- assert self.invoke_and_close(StringIO("x='y'a\n")) == {'x': "ya"}
- assert self.invoke_and_close(StringIO('x="y \\\nasdf"')) == {'x': 'y asdf'}
+ assert self.invoke_and_close(StringIO("x='y \\\na'")) == {"x": "y \\\na"}
+ assert self.invoke_and_close(StringIO("x='y'a\n")) == {"x": "ya"}
+ assert self.invoke_and_close(StringIO('x="y \\\nasdf"')) == {"x": "y asdf"}
def test_eof_without_newline(self):
- assert self.invoke_and_close(StringIO("x=y")) == {'x': 'y'}
- assert self.invoke_and_close(StringIO("x='y'a")) == {'x': 'ya'}
+ assert self.invoke_and_close(StringIO("x=y")) == {"x": "y"}
+ assert self.invoke_and_close(StringIO("x='y'a")) == {"x": "ya"}
def test_sourcing(self):
- output = self.invoke_and_close(str(self.sourcing_file), sourcing_command='source')
- expected = {'foo1': 'bar', 'foo2': 'bar', 'foo3': 'bar', 'foo4': '-/:j4', 'foo5': '', 'foo6': 'bar'}
+ output = self.invoke_and_close(
+ str(self.sourcing_file), sourcing_command="source"
+ )
+ expected = {
+ "foo1": "bar",
+ "foo2": "bar",
+ "foo3": "bar",
+ "foo4": "-/:j4",
+ "foo5": "",
+ "foo6": "bar",
+ }
assert output == expected
- output = self.invoke_and_close(str(self.sourcing_file2), sourcing_command='source')
- expected = {'foo1': 'bar', 'foo2': 'bar', 'foo3': 'bar', 'foo4': '-/:j4', 'foo5': '', 'foo6': 'bar'}
+ output = self.invoke_and_close(
+ str(self.sourcing_file2), sourcing_command="source"
+ )
+ expected = {
+ "foo1": "bar",
+ "foo2": "bar",
+ "foo3": "bar",
+ "foo4": "-/:j4",
+ "foo5": "",
+ "foo6": "bar",
+ }
assert output == expected
def test_read_advanced(self):
output = self.invoke_and_close(str(self.advanced_file))
expected = {
- 'one1': '1',
- 'one_': '1',
- 'two1': '2',
- 'two_': '2',
+ "one1": "1",
+ "one_": "1",
+ "two1": "2",
+ "two_": "2",
}
assert output == expected
def test_env(self):
- assert self.invoke_and_close(str(self.env_file)) == {'imported': ''}
- env = {'external': 'imported foo'}
+ assert self.invoke_and_close(str(self.env_file)) == {"imported": ""}
+ env = {"external": "imported foo"}
env_backup = env.copy()
- assert self.invoke_and_close(str(self.env_file), env) == {'imported': 'imported foo'}
+ assert self.invoke_and_close(str(self.env_file), env) == {
+ "imported": "imported foo"
+ }
assert env_backup == env
def test_escaping(self):
output = self.invoke_and_close(str(self.escaped_file))
expected = {
- 'end': 'bye',
- 'quoteddollar': '${dollar}',
- 'quotedexpansion': '${bye}',
+ "end": "bye",
+ "quoteddollar": "${dollar}",
+ "quotedexpansion": "${bye}",
}
assert output == expected
diff --git a/tests/test_caching.py b/tests/test_caching.py
index eaa5014c..06615d38 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -5,17 +5,20 @@ from snakeoil.caching import WeakInstMeta
class weak_slotted(metaclass=WeakInstMeta):
__inst_caching__ = True
- __slots__ = ('one',)
+ __slots__ = ("one",)
class weak_inst(metaclass=WeakInstMeta):
__inst_caching__ = True
counter = 0
+
def __new__(cls, *args, **kwargs):
cls.counter += 1
return object.__new__(cls)
+
def __init__(self, *args, **kwargs):
pass
+
@classmethod
def reset(cls):
cls.counter = 0
@@ -34,7 +37,6 @@ class reenabled_weak_inst(automatic_disabled_weak_inst):
class TestWeakInstMeta:
-
def test_reuse(self, kls=weak_inst):
kls.reset()
o = kls()
@@ -99,8 +101,8 @@ class TestWeakInstMeta:
# (RaisingHashFor...).
# UserWarning is ignored and everything other warning is an error.
- @pytest.mark.filterwarnings('ignore::UserWarning')
- @pytest.mark.filterwarnings('error')
+ @pytest.mark.filterwarnings("ignore::UserWarning")
+ @pytest.mark.filterwarnings("error")
def test_uncachable(self):
weak_inst.reset()
@@ -108,21 +110,24 @@ class TestWeakInstMeta:
class RaisingHashForTestUncachable:
def __init__(self, error):
self.error = error
+
def __hash__(self):
raise self.error
assert weak_inst([]) is not weak_inst([])
assert weak_inst.counter == 2
for x in (TypeError, NotImplementedError):
- assert weak_inst(RaisingHashForTestUncachable(x)) is not \
- weak_inst(RaisingHashForTestUncachable(x))
+ assert weak_inst(RaisingHashForTestUncachable(x)) is not weak_inst(
+ RaisingHashForTestUncachable(x)
+ )
- @pytest.mark.filterwarnings('error::UserWarning')
+ @pytest.mark.filterwarnings("error::UserWarning")
def test_uncachable_warning_msg(self):
# This name is *important*, see above.
class RaisingHashForTestUncachableWarnings:
def __init__(self, error):
self.error = error
+
def __hash__(self):
raise self.error
@@ -134,6 +139,7 @@ class TestWeakInstMeta:
class BrokenHash:
def __hash__(self):
return 1
+
assert weak_inst(BrokenHash()) is not weak_inst(BrokenHash())
def test_weak_slot(self):
@@ -148,7 +154,7 @@ class TestWeakInstMeta:
# The actual test is that the class definition works.
class ExistingWeakrefSlot:
__inst_caching__ = True
- __slots__ = ('one', '__weakref__')
+ __slots__ = ("one", "__weakref__")
assert ExistingWeakrefSlot()
diff --git a/tests/test_chksum.py b/tests/test_chksum.py
index 016e3f73..b4c1ab24 100644
--- a/tests/test_chksum.py
+++ b/tests/test_chksum.py
@@ -3,15 +3,16 @@ from snakeoil import chksum
class Test_funcs:
-
def setup_method(self, method):
chksum.__inited__ = False
chksum.chksum_types.clear()
self._saved_init = chksum.init
self._inited_count = 0
+
def f():
self._inited_count += 1
chksum.__inited__ = True
+
chksum.init = f
# ensure we aren't mangling chksum state for other tests.
@@ -41,4 +42,3 @@ class Test_funcs:
assert chksum.get_handler("x") == 1
assert chksum.get_handler("y") == 2
assert self._inited_count == 1
-
diff --git a/tests/test_chksum_defaults.py b/tests/test_chksum_defaults.py
index 7f867d8a..a22d339c 100644
--- a/tests/test_chksum_defaults.py
+++ b/tests/test_chksum_defaults.py
@@ -14,14 +14,15 @@ def require_chf(func):
def subfunc(self):
if self.chf is None:
pytest.skip(
- 'no handler for %s, do you need to install PyCrypto or mhash?'
- % (self.chf_type,))
+ "no handler for %s, do you need to install PyCrypto or mhash?"
+ % (self.chf_type,)
+ )
func(self)
+
return subfunc
class base:
-
def get_chf(self):
try:
self.chf = chksum.get_handler(self.chf_type)
@@ -53,14 +54,17 @@ class base:
@require_chf
def test_data_source_check(self):
assert self.chf(local_source(self.fn)) == self.expected_long
- assert self.chf(data_source(fileutils.readfile_ascii(self.fn))) == self.expected_long
+ assert (
+ self.chf(data_source(fileutils.readfile_ascii(self.fn)))
+ == self.expected_long
+ )
-class ChksumTest(base):
+class ChksumTest(base):
@require_chf
def test_str2long(self):
assert self.chf.str2long(self.expected_str) == self.expected_long
- if self.chf_type == 'size':
+ if self.chf_type == "size":
return
for x in extra_chksums.get(self.chf_type, ()):
assert self.chf.str2long(x) == int(x, 16)
@@ -68,11 +72,12 @@ class ChksumTest(base):
@require_chf
def test_long2str(self):
assert self.chf.long2str(self.expected_long) == self.expected_str
- if self.chf_type == 'size':
+ if self.chf_type == "size":
return
for x in extra_chksums.get(self.chf_type, ()):
assert self.chf.long2str(int(x == 16)), x
+
checksums = {
"rmd160": "b83ad488d624e7911f886420ab230f78f6368b9f",
"sha1": "63cd8cce8a1773dffb400ee184be3ec7d89791f5",
@@ -87,22 +92,22 @@ checksums = {
checksums.update((k, (int(v, 16), v)) for k, v in checksums.items())
checksums["size"] = (int(len(data) * multi), str(int(len(data) * multi)))
-extra_chksums = {
- "md5":
- ["2dfd84279314a178d0fa842af3a40e25577e1bc"]
-}
+extra_chksums = {"md5": ["2dfd84279314a178d0fa842af3a40e25577e1bc"]}
for k, v in checksums.items():
- extra_chksums.setdefault(k, []).extend((''.rjust(len(v[1]), '0'), '01'.rjust(len(v[1]), '0')))
+ extra_chksums.setdefault(k, []).extend(
+ ("".rjust(len(v[1]), "0"), "01".rjust(len(v[1]), "0"))
+ )
# trick: create subclasses for each checksum with a useful class name.
for chf_type, expected in checksums.items():
expectedsum = expected[0]
expectedstr = expected[1]
- globals()['TestChksum' + chf_type.capitalize()] = type(
- 'TestChksum' + chf_type.capitalize(),
+ globals()["TestChksum" + chf_type.capitalize()] = type(
+ "TestChksum" + chf_type.capitalize(),
(ChksumTest,),
- dict(chf_type=chf_type, expected_long=expectedsum, expected_str=expectedstr))
+ dict(chf_type=chf_type, expected_long=expectedsum, expected_str=expectedstr),
+ )
# pylint: disable=undefined-loop-variable
del chf_type, expected
@@ -110,7 +115,7 @@ del chf_type, expected
class TestGetChksums(base):
- chfs = [k for k in sorted(checksums) if k in ('md5', 'sha1')]
+ chfs = [k for k in sorted(checksums) if k in ("md5", "sha1")]
expected_long = [checksums[k][0] for k in chfs]
def get_chf(self):
diff --git a/tests/test_constraints.py b/tests/test_constraints.py
index 5e938a9e..15d360c2 100644
--- a/tests/test_constraints.py
+++ b/tests/test_constraints.py
@@ -2,61 +2,75 @@ import pytest
from snakeoil.constraints import Problem
+
def any_of(**kwargs):
return any(kwargs.values())
+
def all_of(**kwargs):
return all(kwargs.values())
+
def test_readd_variables():
p = Problem()
- p.add_variable((True, False), 'x', 'y')
+ p.add_variable((True, False), "x", "y")
with pytest.raises(AssertionError, match="variable 'y' was already added"):
- p.add_variable((True, False), 'y', 'z')
+ p.add_variable((True, False), "y", "z")
+
def test_constraint_unknown_variable():
p = Problem()
- p.add_variable((True, False), 'x', 'y')
+ p.add_variable((True, False), "x", "y")
with pytest.raises(AssertionError, match="unknown variable 'z'"):
- p.add_constraint(any_of, ('y', 'z'))
+ p.add_constraint(any_of, ("y", "z"))
+
def test_empty_problem():
p = Problem()
- assert tuple(p) == ({}, )
+ assert tuple(p) == ({},)
+
def test_empty_constraints():
p = Problem()
- p.add_variable((True, False), 'x', 'y')
- p.add_variable((True, ), 'z')
+ p.add_variable((True, False), "x", "y")
+ p.add_variable((True,), "z")
assert len(tuple(p)) == 4
+
def test_domain_prefer_later():
p = Problem()
- p.add_variable((False, True), 'x', 'y')
- p.add_constraint(any_of, ('x', 'y'))
- assert next(iter(p)) == {'x': True, 'y': True}
+ p.add_variable((False, True), "x", "y")
+ p.add_constraint(any_of, ("x", "y"))
+ assert next(iter(p)) == {"x": True, "y": True}
+
def test_constraint_single_variable():
p = Problem()
- p.add_variable((True, False), 'x', 'y')
- p.add_constraint(lambda x: x, ('x', ))
- p.add_constraint(lambda y: not y, ('y', ))
- assert tuple(p) == ({'x': True, 'y': False}, )
+ p.add_variable((True, False), "x", "y")
+ p.add_constraint(lambda x: x, ("x",))
+ p.add_constraint(lambda y: not y, ("y",))
+ assert tuple(p) == ({"x": True, "y": False},)
+
def test_no_solution():
p = Problem()
- p.add_variable((True, ), 'x')
- p.add_variable((True, False), 'y', 'z')
- p.add_constraint(lambda x, y: not x or y, ('x', 'y'))
- p.add_constraint(lambda y, z: not y or not z, ('y', 'z'))
- p.add_constraint(lambda x, z: not x or z, ('x', 'z'))
+ p.add_variable((True,), "x")
+ p.add_variable((True, False), "y", "z")
+ p.add_constraint(lambda x, y: not x or y, ("x", "y"))
+ p.add_constraint(lambda y, z: not y or not z, ("y", "z"))
+ p.add_constraint(lambda x, z: not x or z, ("x", "z"))
assert not tuple(p)
+
def test_forward_check():
p = Problem()
- p.add_variable(range(2, 10), 'x', 'y', 'z')
- p.add_constraint(lambda x, y: (x + y) % 2 == 0, ('x', 'y'))
- p.add_constraint(lambda x, y, z: (x * y * z) % 2 != 0, ('x', 'y', 'z'))
- p.add_constraint(lambda y, z: y < z, ('y', 'z'))
- p.add_constraint(lambda z, x: x ** 2 <= z, ('x', 'z'))
- assert tuple(p) == ({'x': 3, 'y': 7, 'z': 9}, {'x': 3, 'y': 5, 'z': 9}, {'x': 3, 'y': 3, 'z': 9})
+ p.add_variable(range(2, 10), "x", "y", "z")
+ p.add_constraint(lambda x, y: (x + y) % 2 == 0, ("x", "y"))
+ p.add_constraint(lambda x, y, z: (x * y * z) % 2 != 0, ("x", "y", "z"))
+ p.add_constraint(lambda y, z: y < z, ("y", "z"))
+ p.add_constraint(lambda z, x: x**2 <= z, ("x", "z"))
+ assert tuple(p) == (
+ {"x": 3, "y": 7, "z": 9},
+ {"x": 3, "y": 5, "z": 9},
+ {"x": 3, "y": 3, "z": 9},
+ )
diff --git a/tests/test_containers.py b/tests/test_containers.py
index f6940c90..9df32588 100644
--- a/tests/test_containers.py
+++ b/tests/test_containers.py
@@ -5,7 +5,6 @@ from snakeoil import containers
class TestInvertedContains:
-
def setup_method(self, method):
self.set = containers.InvertedContains(range(12))
@@ -17,7 +16,7 @@ class TestInvertedContains:
class BasicSet(containers.SetMixin):
- __slots__ = ('_data',)
+ __slots__ = ("_data",)
def __init__(self, data):
self._data = set(data)
@@ -28,7 +27,7 @@ class BasicSet(containers.SetMixin):
def __contains__(self, other):
return other in self._data
- #def __str__(self):
+ # def __str__(self):
# return 'BasicSet([%s])' % ', '.join((str(x) for x in self._data))
def __eq__(self, other):
@@ -43,7 +42,6 @@ class BasicSet(containers.SetMixin):
class TestSetMethods:
-
def test_and(self):
c = BasicSet(range(100))
s = set(range(25, 75))
@@ -80,8 +78,8 @@ class TestSetMethods:
assert c - s == r1
assert s - c == r2
-class TestLimitedChangeSet:
+class TestLimitedChangeSet:
def setup_method(self, method):
self.set = containers.LimitedChangeSet(range(12))
@@ -89,17 +87,18 @@ class TestLimitedChangeSet:
def f(val):
assert isinstance(val, int)
return val
+
self.set = containers.LimitedChangeSet(range(12), key_validator=f)
self.set.add(13)
self.set.add(14)
self.set.remove(11)
assert 5 in self.set
with pytest.raises(AssertionError):
- self.set.add('2')
+ self.set.add("2")
with pytest.raises(AssertionError):
- self.set.remove('2')
+ self.set.remove("2")
with pytest.raises(AssertionError):
- self.set.__contains__('2')
+ self.set.__contains__("2")
def test_basic(self, changes=0):
# this should be a no-op
@@ -188,7 +187,7 @@ class TestLimitedChangeSet:
assert sorted(list(self.set)) == list(range(-1, 13))
def test_str(self):
- assert str(containers.LimitedChangeSet([7])) == 'LimitedChangeSet([7])'
+ assert str(containers.LimitedChangeSet([7])) == "LimitedChangeSet([7])"
def test__eq__(self):
c = containers.LimitedChangeSet(range(99))
@@ -199,7 +198,6 @@ class TestLimitedChangeSet:
class TestLimitedChangeSetWithBlacklist:
-
def setup_method(self, method):
self.set = containers.LimitedChangeSet(range(12), [3, 13])
@@ -222,7 +220,6 @@ class TestLimitedChangeSetWithBlacklist:
class TestProtectedSet:
-
def setup_method(self, method):
self.set = containers.ProtectedSet(set(range(12)))
diff --git a/tests/test_contexts.py b/tests/test_contexts.py
index 219d5ee8..be212a23 100644
--- a/tests/test_contexts.py
+++ b/tests/test_contexts.py
@@ -44,9 +44,10 @@ def test_syspath(tmpdir):
assert mangled_syspath == tuple(sys.path)
-@pytest.mark.skip(reason='this currently breaks on github ci, https://github.com/pkgcore/snakeoil/issues/68')
+@pytest.mark.skip(
+ reason="this currently breaks on github ci, https://github.com/pkgcore/snakeoil/issues/68"
+)
class TestSplitExec:
-
def test_context_process(self):
# code inside the with statement is run in a separate process
pid = os.getpid()
@@ -77,9 +78,9 @@ class TestSplitExec:
b = 3
# changes to locals aren't propagated back
assert a == 1
- assert 'b' not in locals()
+ assert "b" not in locals()
# but they're accessible via the 'locals' attr
- expected = {'a': 2, 'b': 3}
+ expected = {"a": 2, "b": 3}
for k, v in expected.items():
assert c.locals[k] == v
@@ -87,20 +88,21 @@ class TestSplitExec:
with SplitExec() as c:
func = lambda x: x
from sys import implementation
+
a = 4
- assert c.locals == {'a': 4}
+ assert c.locals == {"a": 4}
def test_context_exceptions(self):
# exceptions in the child process are sent back to the parent and re-raised
with pytest.raises(IOError) as e:
with SplitExec() as c:
- raise IOError(errno.EBUSY, 'random error')
+ raise IOError(errno.EBUSY, "random error")
assert e.value.errno == errno.EBUSY
def test_child_setup_raises_exception(self):
class ChildSetupException(SplitExec):
def _child_setup(self):
- raise IOError(errno.EBUSY, 'random error')
+ raise IOError(errno.EBUSY, "random error")
with pytest.raises(IOError) as e:
with ChildSetupException() as c:
@@ -108,26 +110,33 @@ class TestSplitExec:
assert e.value.errno == errno.EBUSY
-@pytest.mark.skipif(not sys.platform.startswith('linux'), reason='supported on Linux only')
-@pytest.mark.xfail(platform.python_implementation() == "PyPy", reason='Fails on PyPy')
+@pytest.mark.skipif(
+ not sys.platform.startswith("linux"), reason="supported on Linux only"
+)
+@pytest.mark.xfail(platform.python_implementation() == "PyPy", reason="Fails on PyPy")
class TestNamespace:
-
- @pytest.mark.skipif(not os.path.exists('/proc/self/ns/user'),
- reason='user namespace support required')
+ @pytest.mark.skipif(
+ not os.path.exists("/proc/self/ns/user"),
+ reason="user namespace support required",
+ )
def test_user_namespace(self):
try:
with Namespace(user=True) as ns:
assert os.getuid() == 0
except PermissionError:
- pytest.skip('No permission to use user namespace')
-
- @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/user') and os.path.exists('/proc/self/ns/uts')),
- reason='user and uts namespace support required')
+ pytest.skip("No permission to use user namespace")
+
+ @pytest.mark.skipif(
+ not (
+ os.path.exists("/proc/self/ns/user") and os.path.exists("/proc/self/ns/uts")
+ ),
+ reason="user and uts namespace support required",
+ )
def test_uts_namespace(self):
try:
- with Namespace(user=True, uts=True, hostname='host') as ns:
- ns_hostname, _, ns_domainname = socket.getfqdn().partition('.')
- assert ns_hostname == 'host'
- assert ns_domainname == ''
+ with Namespace(user=True, uts=True, hostname="host") as ns:
+ ns_hostname, _, ns_domainname = socket.getfqdn().partition(".")
+ assert ns_hostname == "host"
+ assert ns_domainname == ""
except PermissionError:
- pytest.skip('No permission to use user and uts namespace')
+ pytest.skip("No permission to use user and uts namespace")
diff --git a/tests/test_currying.py b/tests/test_currying.py
index d14f0fcc..7f8618fc 100644
--- a/tests/test_currying.py
+++ b/tests/test_currying.py
@@ -5,8 +5,10 @@ from snakeoil import currying
def passthrough(*args, **kwargs):
return args, kwargs
+
# docstring is part of the test
+
def documented():
"""original docstring"""
@@ -18,36 +20,37 @@ class TestPreCurry:
def test_pre_curry(self):
noop = self.pre_curry(passthrough)
assert noop() == ((), {})
- assert noop('foo', 'bar') == (('foo', 'bar'), {})
- assert noop(foo='bar') == ((), {'foo': 'bar'})
- assert noop('foo', bar='baz') == (('foo',), {'bar': 'baz'})
+ assert noop("foo", "bar") == (("foo", "bar"), {})
+ assert noop(foo="bar") == ((), {"foo": "bar"})
+ assert noop("foo", bar="baz") == (("foo",), {"bar": "baz"})
one_arg = self.pre_curry(passthrough, 42)
assert one_arg() == ((42,), {})
- assert one_arg('foo', 'bar') == ((42, 'foo', 'bar'), {})
- assert one_arg(foo='bar') == ((42,), {'foo': 'bar'})
- assert one_arg('foo', bar='baz') == ((42, 'foo'), {'bar': 'baz'})
+ assert one_arg("foo", "bar") == ((42, "foo", "bar"), {})
+ assert one_arg(foo="bar") == ((42,), {"foo": "bar"})
+ assert one_arg("foo", bar="baz") == ((42, "foo"), {"bar": "baz"})
keyword_arg = self.pre_curry(passthrough, foo=42)
- assert keyword_arg() == ((), {'foo': 42})
- assert keyword_arg('foo', 'bar') == (('foo', 'bar'), {'foo': 42})
- assert keyword_arg(foo='bar') == ((), {'foo': 'bar'})
- assert keyword_arg('foo', bar='baz') == (('foo',), {'bar': 'baz', 'foo': 42})
+ assert keyword_arg() == ((), {"foo": 42})
+ assert keyword_arg("foo", "bar") == (("foo", "bar"), {"foo": 42})
+ assert keyword_arg(foo="bar") == ((), {"foo": "bar"})
+ assert keyword_arg("foo", bar="baz") == (("foo",), {"bar": "baz", "foo": 42})
both = self.pre_curry(passthrough, 42, foo=42)
- assert both() == ((42,), {'foo': 42})
- assert both('foo', 'bar') == ((42, 'foo', 'bar'), {'foo': 42})
- assert both(foo='bar') == ((42,), {'foo': 'bar'})
- assert both('foo', bar='baz') == ((42, 'foo'), {'bar': 'baz', 'foo': 42})
+ assert both() == ((42,), {"foo": 42})
+ assert both("foo", "bar") == ((42, "foo", "bar"), {"foo": 42})
+ assert both(foo="bar") == ((42,), {"foo": "bar"})
+ assert both("foo", bar="baz") == ((42, "foo"), {"bar": "baz", "foo": 42})
def test_curry_original(self):
assert self.pre_curry(passthrough).func is passthrough
def test_instancemethod(self):
class Test:
- method = self.pre_curry(passthrough, 'test')
+ method = self.pre_curry(passthrough, "test")
+
test = Test()
- assert (('test', test), {}) == test.method()
+ assert (("test", test), {}) == test.method()
class Test_pretty_docs:
@@ -56,58 +59,63 @@ class Test_pretty_docs:
def test_module_magic(self):
for target in self.currying_targets:
- assert currying.pretty_docs(target(passthrough)).__module__ is \
- passthrough.__module__
+ assert (
+ currying.pretty_docs(target(passthrough)).__module__
+ is passthrough.__module__
+ )
# test is kinda useless if they are identical without pretty_docs
- assert getattr(target(passthrough), '__module__', None) is not \
- passthrough.__module__
+ assert (
+ getattr(target(passthrough), "__module__", None)
+ is not passthrough.__module__
+ )
def test_pretty_docs(self):
for target in self.currying_targets:
for func in (passthrough, documented):
- assert currying.pretty_docs(target(func), 'new doc').__doc__ == 'new doc'
+ assert (
+ currying.pretty_docs(target(func), "new doc").__doc__ == "new doc"
+ )
assert currying.pretty_docs(target(func)).__doc__ is func.__doc__
class TestPostCurry:
-
def test_post_curry(self):
noop = currying.post_curry(passthrough)
assert noop() == ((), {})
- assert noop('foo', 'bar') == (('foo', 'bar'), {})
- assert noop(foo='bar') == ((), {'foo': 'bar'})
- assert noop('foo', bar='baz') == (('foo',), {'bar': 'baz'})
+ assert noop("foo", "bar") == (("foo", "bar"), {})
+ assert noop(foo="bar") == ((), {"foo": "bar"})
+ assert noop("foo", bar="baz") == (("foo",), {"bar": "baz"})
one_arg = currying.post_curry(passthrough, 42)
assert one_arg() == ((42,), {})
- assert one_arg('foo', 'bar') == (('foo', 'bar', 42), {})
- assert one_arg(foo='bar') == ((42,), {'foo': 'bar'})
- assert one_arg('foo', bar='baz') == (('foo', 42), {'bar': 'baz'})
+ assert one_arg("foo", "bar") == (("foo", "bar", 42), {})
+ assert one_arg(foo="bar") == ((42,), {"foo": "bar"})
+ assert one_arg("foo", bar="baz") == (("foo", 42), {"bar": "baz"})
keyword_arg = currying.post_curry(passthrough, foo=42)
- assert keyword_arg() == ((), {'foo': 42})
- assert keyword_arg('foo', 'bar') == (('foo', 'bar'), {'foo': 42})
- assert keyword_arg(foo='bar') == ((), {'foo': 42})
- assert keyword_arg('foo', bar='baz') == (('foo',), {'bar': 'baz', 'foo': 42})
+ assert keyword_arg() == ((), {"foo": 42})
+ assert keyword_arg("foo", "bar") == (("foo", "bar"), {"foo": 42})
+ assert keyword_arg(foo="bar") == ((), {"foo": 42})
+ assert keyword_arg("foo", bar="baz") == (("foo",), {"bar": "baz", "foo": 42})
both = currying.post_curry(passthrough, 42, foo=42)
- assert both() == ((42,), {'foo': 42})
- assert both('foo', 'bar') == (('foo', 'bar', 42), {'foo': 42})
- assert both(foo='bar') == ((42,), {'foo': 42})
- assert both('foo', bar='baz') == (('foo', 42), {'bar': 'baz', 'foo': 42})
+ assert both() == ((42,), {"foo": 42})
+ assert both("foo", "bar") == (("foo", "bar", 42), {"foo": 42})
+ assert both(foo="bar") == ((42,), {"foo": 42})
+ assert both("foo", bar="baz") == (("foo", 42), {"bar": "baz", "foo": 42})
def test_curry_original(self):
assert currying.post_curry(passthrough).func is passthrough
def test_instancemethod(self):
class Test:
- method = currying.post_curry(passthrough, 'test')
+ method = currying.post_curry(passthrough, "test")
+
test = Test()
- assert ((test, 'test'), {}) == test.method()
+ assert ((test, "test"), {}) == test.method()
class Test_wrap_exception:
-
def test_wrap_exception_complex(self):
inner, outer = [], []
@@ -118,33 +126,33 @@ class Test_wrap_exception:
assert isinstance(exception, inner_exception)
assert functor is throwing_func
assert fargs == (False,)
- assert fkwds == {'monkey': 'bone'}
+ assert fkwds == {"monkey": "bone"}
outer.append(True)
raise wrapping_exception()
def throwing_func(*args, **kwds):
assert args == (False,)
- assert kwds == {'monkey': 'bone'}
+ assert kwds == {"monkey": "bone"}
inner.append(True)
raise inner_exception()
func = currying.wrap_exception_complex(f, IndexError)(throwing_func)
# basic behaviour
- pytest.raises(IndexError, func, False, monkey='bone')
+ pytest.raises(IndexError, func, False, monkey="bone")
assert len(inner) == 1
assert len(outer) == 1
# ensure pass thru if it's an allowed exception
inner_exception = IndexError
- pytest.raises(IndexError, func, False, monkey='bone')
+ pytest.raises(IndexError, func, False, monkey="bone")
assert len(inner) == 2
assert len(outer) == 1
# finally, ensure it doesn't intercept, and passes thru for
# exceptions it shouldn't handle
inner_exception = MemoryError
- pytest.raises(MemoryError, func, False, monkey='bone')
+ pytest.raises(MemoryError, func, False, monkey="bone")
assert len(inner) == 3
assert len(outer) == 1
@@ -159,9 +167,10 @@ class Test_wrap_exception:
self.args = args
self.kwds = kwds
- func = currying.wrap_exception(my_exception, 1, 3, 2, monkey='bone',
- ignores=ValueError)(throwing_func)
- assert func.__name__ == 'throwing_func'
+ func = currying.wrap_exception(
+ my_exception, 1, 3, 2, monkey="bone", ignores=ValueError
+ )(throwing_func)
+ assert func.__name__ == "throwing_func"
pytest.raises(ValueError, func)
throw_kls = IndexError
pytest.raises(my_exception, func)
@@ -170,17 +179,23 @@ class Test_wrap_exception:
raise AssertionError("shouldn't have been able to reach here")
except my_exception as e:
assert e.args == (1, 3, 2)
- assert e.kwds == {'monkey': 'bone'}
+ assert e.kwds == {"monkey": "bone"}
# finally, verify that the exception can be pased in.
func = currying.wrap_exception(
- my_exception, 1, 3, 2, monkey='bone',
- ignores=ValueError, pass_error="the_exception")(throwing_func)
- assert func.__name__ == 'throwing_func'
+ my_exception,
+ 1,
+ 3,
+ 2,
+ monkey="bone",
+ ignores=ValueError,
+ pass_error="the_exception",
+ )(throwing_func)
+ assert func.__name__ == "throwing_func"
pytest.raises(my_exception, func)
try:
func()
raise AssertionError("shouldn't have been able to reach here")
except my_exception as e:
assert e.args == (1, 3, 2)
- assert e.kwds == {'monkey': 'bone', 'the_exception': e.__cause__}
+ assert e.kwds == {"monkey": "bone", "the_exception": e.__cause__}
diff --git a/tests/test_data_source.py b/tests/test_data_source.py
index ddd3eee6..1ede9aa0 100644
--- a/tests/test_data_source.py
+++ b/tests/test_data_source.py
@@ -53,15 +53,15 @@ class TestDataSource:
assert reader_data == writer_data
def _mk_data(self, size=(100000)):
- return ''.join(str(x % 10) for x in range(size))
+ return "".join(str(x % 10) for x in range(size))
def test_transfer_to_data_source(self):
data = self._mk_data()
reader = self.get_obj(data=data)
if self.supports_mutable:
- writer = self.get_obj(data='', mutable=True)
+ writer = self.get_obj(data="", mutable=True)
else:
- writer = data_source.data_source('', mutable=True)
+ writer = data_source.data_source("", mutable=True)
reader.transfer_to_data_source(writer)
self.assertContents(reader, writer)
@@ -70,9 +70,11 @@ class TestDataSource:
data = self._mk_data()
reader = self.get_obj(data=data)
if isinstance(reader, data_source.bz2_source):
- writer = data_source.bz2_source(tmp_path / 'transfer_to_path', mutable=True)
+ writer = data_source.bz2_source(tmp_path / "transfer_to_path", mutable=True)
else:
- writer = data_source.local_source(tmp_path / 'transfer_to_path', mutable=True)
+ writer = data_source.local_source(
+ tmp_path / "transfer_to_path", mutable=True
+ )
reader.transfer_to_path(writer.path)
@@ -82,9 +84,9 @@ class TestDataSource:
data = self._mk_data()
reader = self.get_obj(data=data)
if self.supports_mutable:
- writer = self.get_obj(data='', mutable=True)
+ writer = self.get_obj(data="", mutable=True)
else:
- writer = data_source.data_source('', mutable=True)
+ writer = data_source.data_source("", mutable=True)
with reader.bytes_fileobj() as reader_f, writer.bytes_fileobj(True) as writer_f:
data_source.transfer_between_files(reader_f, writer_f)
@@ -93,15 +95,14 @@ class TestDataSource:
class TestLocalSource(TestDataSource):
-
def get_obj(self, data="foonani", mutable=False, test_creation=False):
self.fp = self.dir / "localsource.test"
if not test_creation:
mode = None
if isinstance(data, bytes):
- mode = 'wb'
+ mode = "wb"
elif mode is None:
- mode = 'w'
+ mode = "w"
with open(self.fp, mode) as f:
f.write(data)
return data_source.local_source(self.fp, mutable=mutable)
@@ -118,21 +119,20 @@ class TestLocalSource(TestDataSource):
obj = self.get_obj(test_creation=True, mutable=True)
# this will blow up if tries to ascii decode it.
with obj.bytes_fileobj(True) as f:
- assert f.read() == b''
+ assert f.read() == b""
f.write(data)
with obj.bytes_fileobj() as f:
assert f.read() == data
class TestBz2Source(TestDataSource):
-
def get_obj(self, data="foonani", mutable=False, test_creation=False):
self.fp = self.dir / "bz2source.test.bz2"
if not test_creation:
if isinstance(data, str):
data = data.encode()
- with open(self.fp, 'wb') as f:
- f.write(compression.compress_data('bzip2', data))
+ with open(self.fp, "wb") as f:
+ f.write(compression.compress_data("bzip2", data))
return data_source.bz2_source(self.fp, mutable=mutable)
def test_bytes_fileobj(self):
@@ -150,8 +150,7 @@ class Test_invokable_data_source(TestDataSource):
def get_obj(self, data="foonani", mutable=False):
if isinstance(data, str):
data = data.encode("utf8")
- return data_source.invokable_data_source(
- partial(self._get_data, data))
+ return data_source.invokable_data_source(partial(self._get_data, data))
@staticmethod
def _get_data(data, is_text=False):
@@ -168,10 +167,10 @@ class Test_invokable_data_source_wrapper_text(Test_invokable_data_source):
def get_obj(self, mutable=False, data="foonani"):
return data_source.invokable_data_source.wrap_function(
- partial(self._get_data, data),
- self.text_mode)
+ partial(self._get_data, data), self.text_mode
+ )
- def _get_data(self, data='foonani'):
+ def _get_data(self, data="foonani"):
if isinstance(data, str):
if not self.text_mode:
return data.encode("utf8")
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
index ce00ac4b..92c0fb8c 100644
--- a/tests/test_decorators.py
+++ b/tests/test_decorators.py
@@ -8,7 +8,6 @@ from snakeoil.decorators import coroutine, namespace, splitexec
class TestSplitExecDecorator:
-
def setup_method(self, method):
self.pid = os.getpid()
@@ -18,11 +17,14 @@ class TestSplitExecDecorator:
assert self.pid != os.getpid()
-@pytest.mark.skipif(not sys.platform.startswith('linux'), reason='supported on Linux only')
+@pytest.mark.skipif(
+ not sys.platform.startswith("linux"), reason="supported on Linux only"
+)
class TestNamespaceDecorator:
-
- @pytest.mark.skipif(not os.path.exists('/proc/self/ns/user'),
- reason='user namespace support required')
+ @pytest.mark.skipif(
+ not os.path.exists("/proc/self/ns/user"),
+ reason="user namespace support required",
+ )
def test_user_namespace(self):
@namespace(user=True)
def do_test():
@@ -31,31 +33,34 @@ class TestNamespaceDecorator:
try:
do_test()
except PermissionError:
- pytest.skip('No permission to use user namespace')
-
- @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/user') and os.path.exists('/proc/self/ns/uts')),
- reason='user and uts namespace support required')
+ pytest.skip("No permission to use user namespace")
+
+ @pytest.mark.skipif(
+ not (
+ os.path.exists("/proc/self/ns/user") and os.path.exists("/proc/self/ns/uts")
+ ),
+ reason="user and uts namespace support required",
+ )
def test_uts_namespace(self):
- @namespace(user=True, uts=True, hostname='host')
+ @namespace(user=True, uts=True, hostname="host")
def do_test():
- ns_hostname, _, ns_domainname = socket.getfqdn().partition('.')
- assert ns_hostname == 'host'
- assert ns_domainname == ''
+ ns_hostname, _, ns_domainname = socket.getfqdn().partition(".")
+ assert ns_hostname == "host"
+ assert ns_domainname == ""
try:
do_test()
except PermissionError:
- pytest.skip('No permission to use user and uts namespace')
+ pytest.skip("No permission to use user and uts namespace")
class TestCoroutineDecorator:
-
def test_coroutine(self):
@coroutine
def count():
i = 0
while True:
- val = (yield i)
+ val = yield i
i = val if val is not None else i + 1
cr = count()
diff --git a/tests/test_demandload.py b/tests/test_demandload.py
index ec2661c9..843cf801 100644
--- a/tests/test_demandload.py
+++ b/tests/test_demandload.py
@@ -9,6 +9,7 @@ from snakeoil import demandload
# setup is what the test expects.
# it also explicitly resets the state on the way out.
+
def reset_globals(functor):
def f(*args, **kwds):
orig_demandload = demandload.demandload
@@ -22,60 +23,61 @@ def reset_globals(functor):
demandload.demand_compile_regexp = orig_demand_compile
demandload._protection_enabled = orig_protection
demandload._noisy_protection = orig_noisy
+
return f
class TestParser:
-
@reset_globals
def test_parse(self):
for input, output in [
- ('foo', [('foo', 'foo')]),
- ('foo:bar', [('foo.bar', 'bar')]),
- ('foo:bar,baz@spork', [('foo.bar', 'bar'), ('foo.baz', 'spork')]),
- ('foo@bar', [('foo', 'bar')]),
- ('foo_bar', [('foo_bar', 'foo_bar')]),
- ]:
+ ("foo", [("foo", "foo")]),
+ ("foo:bar", [("foo.bar", "bar")]),
+ ("foo:bar,baz@spork", [("foo.bar", "bar"), ("foo.baz", "spork")]),
+ ("foo@bar", [("foo", "bar")]),
+ ("foo_bar", [("foo_bar", "foo_bar")]),
+ ]:
assert output == list(demandload.parse_imports([input]))
- pytest.raises(ValueError, list, demandload.parse_imports(['a.b']))
- pytest.raises(ValueError, list, demandload.parse_imports(['a:,']))
- pytest.raises(ValueError, list, demandload.parse_imports(['a:b,x@']))
- pytest.raises(ValueError, list, demandload.parse_imports(['b-x']))
- pytest.raises(ValueError, list, demandload.parse_imports([' b_x']))
+ pytest.raises(ValueError, list, demandload.parse_imports(["a.b"]))
+ pytest.raises(ValueError, list, demandload.parse_imports(["a:,"]))
+ pytest.raises(ValueError, list, demandload.parse_imports(["a:b,x@"]))
+ pytest.raises(ValueError, list, demandload.parse_imports(["b-x"]))
+ pytest.raises(ValueError, list, demandload.parse_imports([" b_x"]))
class TestPlaceholder:
-
@reset_globals
def test_getattr(self):
scope = {}
- placeholder = demandload.Placeholder(scope, 'foo', list)
- assert scope == object.__getattribute__(placeholder, '_scope')
+ placeholder = demandload.Placeholder(scope, "foo", list)
+ assert scope == object.__getattribute__(placeholder, "_scope")
assert placeholder.__doc__ == [].__doc__
- assert scope['foo'] == []
+ assert scope["foo"] == []
demandload._protection_enabled = lambda: True
with pytest.raises(ValueError):
- getattr(placeholder, '__doc__')
+ getattr(placeholder, "__doc__")
@reset_globals
def test__str__(self):
scope = {}
- placeholder = demandload.Placeholder(scope, 'foo', list)
- assert scope == object.__getattribute__(placeholder, '_scope')
+ placeholder = demandload.Placeholder(scope, "foo", list)
+ assert scope == object.__getattribute__(placeholder, "_scope")
assert str(placeholder) == str([])
- assert scope['foo'] == []
+ assert scope["foo"] == []
@reset_globals
def test_call(self):
def passthrough(*args, **kwargs):
return args, kwargs
+
def get_func():
return passthrough
+
scope = {}
- placeholder = demandload.Placeholder(scope, 'foo', get_func)
- assert scope == object.__getattribute__(placeholder, '_scope')
- assert (('arg',), {'kwarg': 42}) == placeholder('arg', kwarg=42)
- assert passthrough is scope['foo']
+ placeholder = demandload.Placeholder(scope, "foo", get_func)
+ assert scope == object.__getattribute__(placeholder, "_scope")
+ assert (("arg",), {"kwarg": 42}) == placeholder("arg", kwarg=42)
+ assert passthrough is scope["foo"]
@reset_globals
def test_setattr(self):
@@ -83,45 +85,43 @@ class TestPlaceholder:
pass
scope = {}
- placeholder = demandload.Placeholder(scope, 'foo', Struct)
+ placeholder = demandload.Placeholder(scope, "foo", Struct)
placeholder.val = 7
demandload._protection_enabled = lambda: True
with pytest.raises(ValueError):
- getattr(placeholder, 'val')
- assert 7 == scope['foo'].val
+ getattr(placeholder, "val")
+ assert 7 == scope["foo"].val
class TestImport:
-
@reset_globals
def test_demandload(self):
scope = {}
- demandload.demandload('snakeoil:demandload', scope=scope)
- assert demandload is not scope['demandload']
- assert demandload.demandload is scope['demandload'].demandload
- assert demandload is scope['demandload']
+ demandload.demandload("snakeoil:demandload", scope=scope)
+ assert demandload is not scope["demandload"]
+ assert demandload.demandload is scope["demandload"].demandload
+ assert demandload is scope["demandload"]
@reset_globals
def test_disabled_demandload(self):
scope = {}
- demandload.disabled_demandload('snakeoil:demandload', scope=scope)
- assert demandload is scope['demandload']
+ demandload.disabled_demandload("snakeoil:demandload", scope=scope)
+ assert demandload is scope["demandload"]
class TestDemandCompileRegexp:
-
@reset_globals
def test_demand_compile_regexp(self):
scope = {}
- demandload.demand_compile_regexp('foo', 'frob', scope=scope)
- assert list(scope.keys()) == ['foo']
- assert 'frob' == scope['foo'].pattern
- assert 'frob' == scope['foo'].pattern
+ demandload.demand_compile_regexp("foo", "frob", scope=scope)
+ assert list(scope.keys()) == ["foo"]
+ assert "frob" == scope["foo"].pattern
+ assert "frob" == scope["foo"].pattern
# verify it's delayed via a bad regex.
- demandload.demand_compile_regexp('foo', 'f(', scope=scope)
- assert list(scope.keys()) == ['foo']
+ demandload.demand_compile_regexp("foo", "f(", scope=scope)
+ assert list(scope.keys()) == ["foo"]
# should blow up on accessing an attribute.
- obj = scope['foo']
+ obj = scope["foo"]
with pytest.raises(sre_constants.error):
- getattr(obj, 'pattern')
+ getattr(obj, "pattern")
diff --git a/tests/test_demandload_usage.py b/tests/test_demandload_usage.py
index 79ccfe03..ddde0563 100644
--- a/tests/test_demandload_usage.py
+++ b/tests/test_demandload_usage.py
@@ -4,7 +4,7 @@ from snakeoil.test import mixins
class TestDemandLoadTargets(mixins.PythonNamespaceWalker):
- target_namespace = 'snakeoil'
+ target_namespace = "snakeoil"
ignore_all_import_failures = False
@pytest.fixture(autouse=True)
@@ -16,8 +16,8 @@ class TestDemandLoadTargets(mixins.PythonNamespaceWalker):
def test_demandload_targets(self):
for x in self.walk_namespace(
- self.target_namespace,
- ignore_failed_imports=self.ignore_all_import_failures):
+ self.target_namespace, ignore_failed_imports=self.ignore_all_import_failures
+ ):
self.check_space(x)
def check_space(self, mod):
diff --git a/tests/test_dependant_methods.py b/tests/test_dependant_methods.py
index ffd6d363..186e175a 100644
--- a/tests/test_dependant_methods.py
+++ b/tests/test_dependant_methods.py
@@ -8,7 +8,6 @@ def func(self, seq, data, val=True):
class TestDependantMethods:
-
@staticmethod
def generate_instance(methods, dependencies):
class Class(metaclass=dm.ForcedDepends):
@@ -25,13 +24,15 @@ class TestDependantMethods:
results = []
o = self.generate_instance(
{str(x): currying.post_curry(func, results, x) for x in range(10)},
- {str(x): str(x - 1) for x in range(1, 10)})
+ {str(x): str(x - 1) for x in range(1, 10)},
+ )
getattr(o, "9")()
assert results == list(range(10))
results = []
o = self.generate_instance(
{str(x): currying.post_curry(func, results, x, False) for x in range(10)},
- {str(x): str(x - 1) for x in range(1, 10)})
+ {str(x): str(x - 1) for x in range(1, 10)},
+ )
getattr(o, "9")()
assert results == [0]
getattr(o, "9")()
@@ -41,7 +42,8 @@ class TestDependantMethods:
results = []
o = self.generate_instance(
{str(x): currying.post_curry(func, results, x) for x in range(10)},
- {str(x): str(x - 1) for x in range(1, 10)})
+ {str(x): str(x - 1) for x in range(1, 10)},
+ )
getattr(o, "1")()
assert results == [0, 1]
getattr(o, "2")()
@@ -71,14 +73,15 @@ class TestDependantMethods:
results = []
o = self.generate_instance(
{str(x): currying.post_curry(func, results, x) for x in range(10)},
- {str(x): str(x - 1) for x in range(1, 10)})
- getattr(o, '2')(ignore_deps=True)
+ {str(x): str(x - 1) for x in range(1, 10)},
+ )
+ getattr(o, "2")(ignore_deps=True)
assert [2] == results
def test_no_deps(self):
results = []
o = self.generate_instance(
- {str(x): currying.post_curry(func, results, x) for x in range(10)},
- {})
- getattr(o, '2')()
+ {str(x): currying.post_curry(func, results, x) for x in range(10)}, {}
+ )
+ getattr(o, "2")()
assert [2] == results
diff --git a/tests/test_fileutils.py b/tests/test_fileutils.py
index a4555f89..356eb74d 100644
--- a/tests/test_fileutils.py
+++ b/tests/test_fileutils.py
@@ -13,7 +13,6 @@ from snakeoil.test import random_str
class TestTouch:
-
@pytest.fixture
def random_path(self, tmp_path):
return tmp_path / random_str(10)
@@ -124,19 +123,19 @@ class TestAtomicWriteFile:
def cpy_setup_class(scope, func_name):
- if getattr(fileutils, 'native_%s' % func_name) \
- is getattr(fileutils, func_name):
- scope['skip'] = 'extensions disabled'
+ if getattr(fileutils, "native_%s" % func_name) is getattr(fileutils, func_name):
+ scope["skip"] = "extensions disabled"
else:
- scope['func'] = staticmethod(getattr(fileutils, func_name))
+ scope["func"] = staticmethod(getattr(fileutils, func_name))
+
class Test_readfile:
func = staticmethod(fileutils.readfile)
- test_cases = ['asdf\nfdasswer\1923', '', '987234']
+ test_cases = ["asdf\nfdasswer\1923", "", "987234"]
- default_encoding = 'ascii'
- none_on_missing_ret_data = 'dar'
+ default_encoding = "ascii"
+ none_on_missing_ret_data = "dar"
@staticmethod
def convert_data(data, encoding):
@@ -147,7 +146,7 @@ class Test_readfile:
return data
def test_it(self, tmp_path):
- fp = tmp_path / 'testfile'
+ fp = tmp_path / "testfile"
for expected in self.test_cases:
raised = None
encoding = self.default_encoding
@@ -168,16 +167,16 @@ class Test_readfile:
assert self.func(path) == expected
def test_none_on_missing(self, tmp_path):
- fp = tmp_path / 'nonexistent'
+ fp = tmp_path / "nonexistent"
with pytest.raises(FileNotFoundError):
self.func(fp)
assert self.func(fp, True) is None
- fp.write_bytes(self.convert_data('dar', 'ascii'))
+ fp.write_bytes(self.convert_data("dar", "ascii"))
assert self.func(fp, True) == self.none_on_missing_ret_data
# ensure it handles paths that go through files-
# still should be suppress
- assert self.func(fp / 'extra', True) is None
+ assert self.func(fp / "extra", True) is None
class Test_readfile_ascii(Test_readfile):
@@ -186,85 +185,86 @@ class Test_readfile_ascii(Test_readfile):
class Test_readfile_utf8(Test_readfile):
func = staticmethod(fileutils.readfile_utf8)
- default_encoding = 'utf8'
+ default_encoding = "utf8"
class Test_readfile_bytes(Test_readfile):
func = staticmethod(fileutils.readfile_bytes)
default_encoding = None
- test_cases = list(map(
- currying.post_curry(Test_readfile.convert_data, 'ascii'),
- Test_readfile.test_cases))
- test_cases.append('\ua000fa'.encode("utf8"))
+ test_cases = list(
+ map(
+ currying.post_curry(Test_readfile.convert_data, "ascii"),
+ Test_readfile.test_cases,
+ )
+ )
+ test_cases.append("\ua000fa".encode("utf8"))
none_on_missing_ret_data = Test_readfile.convert_data(
- Test_readfile.none_on_missing_ret_data, 'ascii')
+ Test_readfile.none_on_missing_ret_data, "ascii"
+ )
class readlines_mixin:
-
def assertFunc(self, path, expected):
expected = tuple(expected.split())
- if expected == ('',):
+ if expected == ("",):
expected = ()
- if 'utf8' not in self.encoding_mode:
+ if "utf8" not in self.encoding_mode:
assert tuple(self.func(path)) == expected
return
assert tuple(self.func(path)) == expected
def test_none_on_missing(self, tmp_path):
- fp = tmp_path / 'nonexistent'
+ fp = tmp_path / "nonexistent"
with pytest.raises(FileNotFoundError):
self.func(fp)
assert not tuple(self.func(fp, False, True))
- fp.write_bytes(self.convert_data('dar', 'ascii'))
+ fp.write_bytes(self.convert_data("dar", "ascii"))
assert tuple(self.func(fp, True)) == (self.none_on_missing_ret_data,)
- assert not tuple(self.func(fp / 'missing', False, True))
+ assert not tuple(self.func(fp / "missing", False, True))
def test_strip_whitespace(self, tmp_path):
- fp = tmp_path / 'data'
+ fp = tmp_path / "data"
- fp.write_bytes(self.convert_data(' dar1 \ndar2 \n dar3\n',
- 'ascii'))
+ fp.write_bytes(self.convert_data(" dar1 \ndar2 \n dar3\n", "ascii"))
results = tuple(self.func(fp, True))
- expected = ('dar1', 'dar2', 'dar3')
- if self.encoding_mode == 'bytes':
+ expected = ("dar1", "dar2", "dar3")
+ if self.encoding_mode == "bytes":
expected = tuple(x.encode("ascii") for x in expected)
assert results == expected
# this time without the trailing newline...
- fp.write_bytes(self.convert_data(' dar1 \ndar2 \n dar3',
- 'ascii'))
+ fp.write_bytes(self.convert_data(" dar1 \ndar2 \n dar3", "ascii"))
results = tuple(self.func(fp, True))
assert results == expected
# test a couple of edgecases; underly c extension has gotten these
# wrong before.
- fp.write_bytes(self.convert_data('0', 'ascii'))
+ fp.write_bytes(self.convert_data("0", "ascii"))
results = tuple(self.func(fp, True))
- expected = ('0',)
- if self.encoding_mode == 'bytes':
+ expected = ("0",)
+ if self.encoding_mode == "bytes":
expected = tuple(x.encode("ascii") for x in expected)
assert results == expected
- fp.write_bytes(self.convert_data('0\n', 'ascii'))
+ fp.write_bytes(self.convert_data("0\n", "ascii"))
results = tuple(self.func(fp, True))
- expected = ('0',)
- if self.encoding_mode == 'bytes':
+ expected = ("0",)
+ if self.encoding_mode == "bytes":
expected = tuple(x.encode("ascii") for x in expected)
assert results == expected
- fp.write_bytes(self.convert_data('0 ', 'ascii'))
+ fp.write_bytes(self.convert_data("0 ", "ascii"))
results = tuple(self.func(fp, True))
- expected = ('0',)
- if self.encoding_mode == 'bytes':
+ expected = ("0",)
+ if self.encoding_mode == "bytes":
expected = tuple(x.encode("ascii") for x in expected)
assert results == expected
def mk_readlines_test(scope, mode):
- func_name = 'readlines_%s' % mode
- base = globals()['Test_readfile_%s' % mode]
+ func_name = "readlines_%s" % mode
+ base = globals()["Test_readfile_%s" % mode]
class kls(readlines_mixin, base):
func = staticmethod(getattr(fileutils, func_name))
@@ -273,14 +273,15 @@ def mk_readlines_test(scope, mode):
kls.__name__ = "Test_%s" % func_name
scope["Test_%s" % func_name] = kls
+
for case in ("ascii", "bytes", "utf8"):
- name = 'readlines_%s' % case
+ name = "readlines_%s" % case
mk_readlines_test(locals(), case)
class TestBrokenStats:
- test_cases = ['/proc/crypto', '/sys/devices/system/cpu/present']
+ test_cases = ["/proc/crypto", "/sys/devices/system/cpu/present"]
def test_readfile(self):
for path in self.test_cases:
@@ -292,7 +293,7 @@ class TestBrokenStats:
def _check_path(self, path, func, split_it=False):
try:
- with open(path, 'r') as handle:
+ with open(path, "r") as handle:
data = handle.read()
except EnvironmentError as e:
if e.errno not in (errno.ENOENT, errno.EPERM):
@@ -302,7 +303,7 @@ class TestBrokenStats:
func_data = func(path)
if split_it:
func_data = list(func_data)
- data = [x for x in data.split('\n') if x]
+ data = [x for x in data.split("\n") if x]
func_data = [x for x in func_data if x]
assert func_data == data
@@ -313,13 +314,13 @@ class Test_mmap_or_open_for_read:
func = staticmethod(fileutils.mmap_or_open_for_read)
def test_zero_length(self, tmp_path):
- (path := tmp_path / "target").write_text('')
+ (path := tmp_path / "target").write_text("")
m, f = self.func(path)
assert m is None
- assert f.read() == b''
+ assert f.read() == b""
f.close()
- def test_mmap(self, tmp_path, data=b'foonani'):
+ def test_mmap(self, tmp_path, data=b"foonani"):
(path := tmp_path / "target").write_bytes(data)
m, f = self.func(path)
assert len(m) == len(data)
@@ -329,14 +330,14 @@ class Test_mmap_or_open_for_read:
class Test_mmap_and_close:
-
def test_it(self, tmp_path):
- (path := tmp_path / "target").write_bytes(data := b'asdfasdf')
+ (path := tmp_path / "target").write_bytes(data := b"asdfasdf")
fd, m = None, None
try:
fd = os.open(path, os.O_RDONLY)
m = _fileutils.mmap_and_close(
- fd, len(data), mmap.MAP_PRIVATE, mmap.PROT_READ)
+ fd, len(data), mmap.MAP_PRIVATE, mmap.PROT_READ
+ )
# and ensure it closed the fd...
with pytest.raises(EnvironmentError):
os.read(fd, 1)
diff --git a/tests/test_formatters.py b/tests/test_formatters.py
index 266ef1e0..549f2adc 100644
--- a/tests/test_formatters.py
+++ b/tests/test_formatters.py
@@ -18,16 +18,16 @@ class TestPlainTextFormatter:
def test_basics(self):
# As many sporks as fit in 20 chars.
- sporks = ' '.join(3 * ('spork',))
+ sporks = " ".join(3 * ("spork",))
for inputs, output in [
- (('\N{SNOWMAN}',), '?'),
- ((7 * 'spork ',), '%s\n%s\n%s' % (sporks, sporks, 'spork ')),
- (7 * ('spork ',), '%s \n%s \n%s' % (sporks, sporks, 'spork ')),
- ((30 * 'a'), 20 * 'a' + '\n' + 10 * 'a'),
- (30 * ('a',), 20 * 'a' + '\n' + 10 * 'a'),
- ]:
+ (("\N{SNOWMAN}",), "?"),
+ ((7 * "spork ",), "%s\n%s\n%s" % (sporks, sporks, "spork ")),
+ (7 * ("spork ",), "%s \n%s \n%s" % (sporks, sporks, "spork ")),
+ ((30 * "a"), 20 * "a" + "\n" + 10 * "a"),
+ (30 * ("a",), 20 * "a" + "\n" + 10 * "a"),
+ ]:
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii')
+ formatter = self.kls(stream, encoding="ascii")
formatter.width = 20
formatter.write(autoline=False, wrap=True, *inputs)
assert output.encode() == stream.getvalue()
@@ -35,69 +35,70 @@ class TestPlainTextFormatter:
def test_first_prefix(self):
# As many sporks as fit in 20 chars.
for inputs, output in [
- (('\N{SNOWMAN}',), 'foon:?'),
- ((7 * 'spork ',),
- 'foon:spork spork\n'
- 'spork spork spork\n'
- 'spork spork '),
- (7 * ('spork ',),
- 'foon:spork spork \n'
- 'spork spork spork \n'
- 'spork spork '),
- ((30 * 'a'), 'foon:' + 15 * 'a' + '\n' + 15 * 'a'),
- (30 * ('a',), 'foon:' + 15 * 'a' + '\n' + 15 * 'a'),
- ]:
+ (("\N{SNOWMAN}",), "foon:?"),
+ (
+ (7 * "spork ",),
+ "foon:spork spork\n" "spork spork spork\n" "spork spork ",
+ ),
+ (
+ 7 * ("spork ",),
+ "foon:spork spork \n" "spork spork spork \n" "spork spork ",
+ ),
+ ((30 * "a"), "foon:" + 15 * "a" + "\n" + 15 * "a"),
+ (30 * ("a",), "foon:" + 15 * "a" + "\n" + 15 * "a"),
+ ]:
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii')
+ formatter = self.kls(stream, encoding="ascii")
formatter.width = 20
- formatter.write(autoline=False, wrap=True, first_prefix='foon:', *inputs)
+ formatter.write(autoline=False, wrap=True, first_prefix="foon:", *inputs)
assert output.encode() == stream.getvalue()
def test_later_prefix(self):
for inputs, output in [
- (('\N{SNOWMAN}',), '?'),
- ((7 * 'spork ',),
- 'spork spork spork\n'
- 'foon:spork spork\n'
- 'foon:spork spork '),
- (7 * ('spork ',),
- 'spork spork spork \n'
- 'foon:spork spork \n'
- 'foon:spork spork '),
- ((30 * 'a'), 20 * 'a' + '\n' + 'foon:' + 10 * 'a'),
- (30 * ('a',), 20 * 'a' + '\n' + 'foon:' + 10 * 'a'),
- ]:
+ (("\N{SNOWMAN}",), "?"),
+ (
+ (7 * "spork ",),
+ "spork spork spork\n" "foon:spork spork\n" "foon:spork spork ",
+ ),
+ (
+ 7 * ("spork ",),
+ "spork spork spork \n" "foon:spork spork \n" "foon:spork spork ",
+ ),
+ ((30 * "a"), 20 * "a" + "\n" + "foon:" + 10 * "a"),
+ (30 * ("a",), 20 * "a" + "\n" + "foon:" + 10 * "a"),
+ ]:
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii')
+ formatter = self.kls(stream, encoding="ascii")
formatter.width = 20
- formatter.later_prefix = ['foon:']
+ formatter.later_prefix = ["foon:"]
formatter.write(wrap=True, autoline=False, *inputs)
assert output.encode() == stream.getvalue()
def test_complex(self):
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii')
+ formatter = self.kls(stream, encoding="ascii")
formatter.width = 9
- formatter.first_prefix = ['foo', None, ' d']
- formatter.later_prefix = ['dorkey']
+ formatter.first_prefix = ["foo", None, " d"]
+ formatter.later_prefix = ["dorkey"]
formatter.write("dar bl", wrap=True, autoline=False)
assert "foo ddar\ndorkeybl".encode() == stream.getvalue()
- formatter.write(" "*formatter.width, wrap=True, autoline=True)
+ formatter.write(" " * formatter.width, wrap=True, autoline=True)
formatter.stream = stream = BytesIO()
formatter.write("dar", " b", wrap=True, autoline=False)
assert "foo ddar\ndorkeyb".encode() == stream.getvalue()
- output = \
-""" rdepends: >=dev-lang/python-2.3 >=sys-apps/sed-4.0.5
+ output = """ rdepends: >=dev-lang/python-2.3 >=sys-apps/sed-4.0.5
dev-python/python-fchksum
"""
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii', width=80)
+ formatter = self.kls(stream, encoding="ascii", width=80)
formatter.wrap = True
assert formatter.autoline
assert formatter.width == 80
- formatter.later_prefix = [' ']
- formatter.write(" rdepends: >=dev-lang/python-2.3 "
- ">=sys-apps/sed-4.0.5 dev-python/python-fchksum")
+ formatter.later_prefix = [" "]
+ formatter.write(
+ " rdepends: >=dev-lang/python-2.3 "
+ ">=sys-apps/sed-4.0.5 dev-python/python-fchksum"
+ )
assert len(formatter.first_prefix) == 0
assert len(formatter.later_prefix) == 1
assert output.encode() == stream.getvalue()
@@ -105,148 +106,176 @@ class TestPlainTextFormatter:
formatter.stream = stream = BytesIO()
# push it right up to the limit.
formatter.width = 82
- formatter.write(" rdepends: >=dev-lang/python-2.3 "
- ">=sys-apps/sed-4.0.5 dev-python/python-fchksum")
+ formatter.write(
+ " rdepends: >=dev-lang/python-2.3 "
+ ">=sys-apps/sed-4.0.5 dev-python/python-fchksum"
+ )
assert output.encode() == stream.getvalue()
formatter.first_prefix = []
- formatter.later_prefix = [' ']
+ formatter.later_prefix = [" "]
formatter.width = 28
formatter.autoline = False
formatter.wrap = True
formatter.stream = stream = BytesIO()
input = (" description: ", "The Portage")
formatter.write(*input)
- output = ''.join(input).rsplit(" ", 1)
- output[1] = ' %s' % output[1]
- assert '\n'.join(output).encode() == stream.getvalue()
-
+ output = "".join(input).rsplit(" ", 1)
+ output[1] = " %s" % output[1]
+ assert "\n".join(output).encode() == stream.getvalue()
def test_wrap_autoline(self):
for inputs, output in [
- ((3 * ('spork',)), 'spork\nspork\nspork\n'),
- (3 * (('spork',),), 'spork\nspork\nspork\n'),
- (((3 * 'spork',),),
- '\n'
- 'foonsporks\n'
- 'foonporksp\n'
- 'foonork\n'),
- ((('fo',), (2 * 'spork',),), 'fo\nsporkspork\n'),
- ((('fo',), (3 * 'spork',),),
- 'fo\n'
- '\n'
- 'foonsporks\n'
- 'foonporksp\n'
- 'foonork\n'),
- ]:
+ ((3 * ("spork",)), "spork\nspork\nspork\n"),
+ (3 * (("spork",),), "spork\nspork\nspork\n"),
+ (((3 * "spork",),), "\n" "foonsporks\n" "foonporksp\n" "foonork\n"),
+ (
+ (
+ ("fo",),
+ (2 * "spork",),
+ ),
+ "fo\nsporkspork\n",
+ ),
+ (
+ (
+ ("fo",),
+ (3 * "spork",),
+ ),
+ "fo\n" "\n" "foonsporks\n" "foonporksp\n" "foonork\n",
+ ),
+ ]:
stream = BytesIO()
- formatter = self.kls(stream, encoding='ascii')
+ formatter = self.kls(stream, encoding="ascii")
formatter.width = 10
for input in inputs:
- formatter.write(wrap=True, later_prefix='foon', *input)
+ formatter.write(wrap=True, later_prefix="foon", *input)
assert output.encode() == stream.getvalue()
class TerminfoFormatterTest:
-
def _test_stream(self, stream, formatter, inputs, output):
stream.seek(0)
stream.truncate()
formatter.write(*inputs)
stream.seek(0)
result = stream.read()
- output = ''.join(output)
- assert output.encode() == result, \
- "given(%r), expected(%r), got(%r)" % (inputs, output, result)
+ output = "".join(output)
+ assert output.encode() == result, "given(%r), expected(%r), got(%r)" % (
+ inputs,
+ output,
+ result,
+ )
@issue7567
def test_terminfo(self):
- esc = '\x1b['
+ esc = "\x1b["
stream = TemporaryFile()
- f = formatters.TerminfoFormatter(stream, 'ansi', True, 'ascii')
+ f = formatters.TerminfoFormatter(stream, "ansi", True, "ascii")
f.autoline = False
for inputs, output in (
- ((f.bold, 'bold'), (esc, '1m', 'bold', esc, '0;10m')),
- ((f.underline, 'underline'),
- (esc, '4m', 'underline', esc, '0;10m')),
- ((f.fg('red'), 'red'), (esc, '31m', 'red', esc, '39;49m')),
- ((f.fg('red'), 'red', f.bold, 'boldred', f.fg(), 'bold',
- f.reset, 'done'),
- (esc, '31m', 'red', esc, '1m', 'boldred', esc, '39;49m', 'bold',
- esc, '0;10m', 'done')),
- ((42,), ('42',)),
- (('\N{SNOWMAN}',), ('?',))
- ):
+ ((f.bold, "bold"), (esc, "1m", "bold", esc, "0;10m")),
+ ((f.underline, "underline"), (esc, "4m", "underline", esc, "0;10m")),
+ ((f.fg("red"), "red"), (esc, "31m", "red", esc, "39;49m")),
+ (
+ (
+ f.fg("red"),
+ "red",
+ f.bold,
+ "boldred",
+ f.fg(),
+ "bold",
+ f.reset,
+ "done",
+ ),
+ (
+ esc,
+ "31m",
+ "red",
+ esc,
+ "1m",
+ "boldred",
+ esc,
+ "39;49m",
+ "bold",
+ esc,
+ "0;10m",
+ "done",
+ ),
+ ),
+ ((42,), ("42",)),
+ (("\N{SNOWMAN}",), ("?",)),
+ ):
self._test_stream(stream, f, inputs, output)
f.autoline = True
- self._test_stream(
- stream, f, ('lala',), ('lala', '\n'))
+ self._test_stream(stream, f, ("lala",), ("lala", "\n"))
def test_unsupported_term(self):
stream = TemporaryFile()
with pytest.raises(formatters.TerminfoUnsupported):
- formatters.TerminfoFormatter(stream, term='dumb')
+ formatters.TerminfoFormatter(stream, term="dumb")
@issue7567
def test_title(self):
stream = TemporaryFile()
try:
- f = formatters.TerminfoFormatter(stream, 'xterm+sl', True, 'ascii')
+ f = formatters.TerminfoFormatter(stream, "xterm+sl", True, "ascii")
except curses.error:
pytest.skip("xterm+sl not in terminfo db")
- f.title('TITLE')
+ f.title("TITLE")
stream.seek(0)
- assert b'\x1b]0;TITLE\x07' == stream.read()
+ assert b"\x1b]0;TITLE\x07" == stream.read()
def _with_term(term, func, *args, **kwargs):
- orig_term = os.environ.get('TERM')
+ orig_term = os.environ.get("TERM")
try:
- os.environ['TERM'] = term
+ os.environ["TERM"] = term
return func(*args, **kwargs)
finally:
if orig_term is None:
- del os.environ['TERM']
+ del os.environ["TERM"]
else:
- os.environ['TERM'] = orig_term
+ os.environ["TERM"] = orig_term
+
# XXX ripped from pkgcore's test_commandline
-def _get_pty_pair(encoding='ascii'):
+def _get_pty_pair(encoding="ascii"):
master_fd, slave_fd = pty.openpty()
- master = os.fdopen(master_fd, 'rb', 0)
- out = os.fdopen(slave_fd, 'wb', 0)
+ master = os.fdopen(master_fd, "rb", 0)
+ out = os.fdopen(slave_fd, "wb", 0)
return master, out
-@pytest.mark.skip(reason='this currently breaks on github ci due to the issue7567 workaround')
+@pytest.mark.skip(
+ reason="this currently breaks on github ci due to the issue7567 workaround"
+)
class TestGetFormatter:
-
@issue7567
def test_dumb_terminal(self):
master, _out = _get_pty_pair()
- formatter = _with_term('dumb', formatters.get_formatter, master)
+ formatter = _with_term("dumb", formatters.get_formatter, master)
assert isinstance(formatter, formatters.PlainTextFormatter)
@issue7567
def test_vt100_terminal(self):
master, _out = _get_pty_pair()
- formatter = _with_term('vt100', formatters.get_formatter, master)
+ formatter = _with_term("vt100", formatters.get_formatter, master)
assert isinstance(formatter, formatters.PlainTextFormatter)
@issue7567
def test_smart_terminal(self):
master, _out = _get_pty_pair()
- formatter = _with_term('xterm', formatters.get_formatter, master)
+ formatter = _with_term("xterm", formatters.get_formatter, master)
assert isinstance(formatter, formatters.TerminfoFormatter)
@issue7567
def test_not_a_tty(self):
stream = TemporaryFile()
- formatter = _with_term('xterm', formatters.get_formatter, stream)
+ formatter = _with_term("xterm", formatters.get_formatter, stream)
assert isinstance(formatter, formatters.PlainTextFormatter)
@issue7567
def test_no_fd(self):
stream = BytesIO()
- formatter = _with_term('xterm', formatters.get_formatter, stream)
+ formatter = _with_term("xterm", formatters.get_formatter, stream)
assert isinstance(formatter, formatters.PlainTextFormatter)
diff --git a/tests/test_iterables.py b/tests/test_iterables.py
index 3345c5e8..d0d57683 100644
--- a/tests/test_iterables.py
+++ b/tests/test_iterables.py
@@ -1,12 +1,10 @@
import operator
import pytest
-from snakeoil.iterables import (caching_iter, expandable_chain, iter_sort,
- partition)
+from snakeoil.iterables import caching_iter, expandable_chain, iter_sort, partition
class TestPartition:
-
def test_empty(self):
a, b = partition(())
assert list(a) == []
@@ -23,19 +21,18 @@ class TestPartition:
class TestExpandableChain:
-
def test_normal_function(self):
i = [iter(range(100)) for x in range(3)]
e = expandable_chain()
e.extend(i)
- assert list(e) == list(range(100))*3
+ assert list(e) == list(range(100)) * 3
for x in i + [e]:
pytest.raises(StopIteration, x.__next__)
def test_extend(self):
e = expandable_chain()
e.extend(range(100) for i in (1, 2))
- assert list(e) == list(range(100))*2
+ assert list(e) == list(range(100)) * 2
with pytest.raises(StopIteration):
e.extend([[]])
@@ -62,7 +59,6 @@ class TestExpandableChain:
class TestCachingIter:
-
def test_iter_consumption(self):
i = iter(range(100))
c = caching_iter(i)
@@ -147,6 +143,7 @@ class Test_iter_sort:
def test_ordering(self):
def f(l):
return sorted(l, key=operator.itemgetter(0))
+
result = list(iter_sort(f, *[iter(range(x, x + 10)) for x in (30, 20, 0, 10)]))
expected = list(range(40))
assert result == expected
diff --git a/tests/test_klass.py b/tests/test_klass.py
index 773925d2..25728fa5 100644
--- a/tests/test_klass.py
+++ b/tests/test_klass.py
@@ -14,7 +14,8 @@ class Test_GetAttrProxy:
class foo1:
def __init__(self, obj):
self.obj = obj
- __getattr__ = self.kls('obj')
+
+ __getattr__ = self.kls("obj")
class foo2:
pass
@@ -27,18 +28,18 @@ class Test_GetAttrProxy:
o2.foon = "dar"
assert o.foon == "dar"
o.foon = "foo"
- assert o.foon == 'foo'
+ assert o.foon == "foo"
def test_attrlist(self):
def make_class(attr_list=None):
class foo(metaclass=self.kls):
if attr_list is not None:
- locals()['__attr_comparison__'] = attr_list
+ locals()["__attr_comparison__"] = attr_list
with pytest.raises(TypeError):
make_class()
with pytest.raises(TypeError):
- make_class(['foon'])
+ make_class(["foon"])
with pytest.raises(TypeError):
make_class([None])
@@ -47,38 +48,39 @@ class Test_GetAttrProxy:
bar = "baz"
class Test:
- method = self.kls('test')
+ method = self.kls("test")
test = foo()
test = Test()
- assert test.method('bar') == foo.bar
+ assert test.method("bar") == foo.bar
class TestDirProxy:
-
@staticmethod
def noninternal_attrs(obj):
- return sorted(x for x in dir(obj) if not re.match(r'__\w+__', x))
+ return sorted(x for x in dir(obj) if not re.match(r"__\w+__", x))
def test_combined(self):
class foo1:
def __init__(self, obj):
self.obj = obj
- __dir__ = klass.DirProxy('obj')
+
+ __dir__ = klass.DirProxy("obj")
class foo2:
def __init__(self):
- self.attr = 'foo'
+ self.attr = "foo"
o2 = foo2()
o = foo1(o2)
- assert self.noninternal_attrs(o) == ['attr', 'obj']
+ assert self.noninternal_attrs(o) == ["attr", "obj"]
def test_empty(self):
class foo1:
def __init__(self, obj):
self.obj = obj
- __dir__ = klass.DirProxy('obj')
+
+ __dir__ = klass.DirProxy("obj")
class foo2:
pass
@@ -86,23 +88,26 @@ class TestDirProxy:
o2 = foo2()
o = foo1(o2)
assert self.noninternal_attrs(o2) == []
- assert self.noninternal_attrs(o) == ['obj']
+ assert self.noninternal_attrs(o) == ["obj"]
def test_slots(self):
class foo1:
- __slots__ = ('obj',)
+ __slots__ = ("obj",)
+
def __init__(self, obj):
self.obj = obj
- __dir__ = klass.DirProxy('obj')
+
+ __dir__ = klass.DirProxy("obj")
class foo2:
- __slots__ = ('attr',)
+ __slots__ = ("attr",)
+
def __init__(self):
- self.attr = 'foo'
+ self.attr = "foo"
o2 = foo2()
o = foo1(o2)
- assert self.noninternal_attrs(o) == ['attr', 'obj']
+ assert self.noninternal_attrs(o) == ["attr", "obj"]
class Test_contains:
@@ -111,6 +116,7 @@ class Test_contains:
def test_it(self):
class c(dict):
__contains__ = self.func
+
d = c({"1": 2})
assert "1" in d
assert 1 not in d
@@ -122,6 +128,7 @@ class Test_get:
def test_it(self):
class c(dict):
get = self.func
+
d = c({"1": 2})
assert d.get("1") == 2
assert d.get("1", 3) == 2
@@ -142,11 +149,13 @@ class Test_chained_getter:
assert id(self.kls("fa2341fa")) == l[0]
def test_eq(self):
- assert self.kls("asdf", disable_inst_caching=True) == \
- self.kls("asdf", disable_inst_caching=True)
+ assert self.kls("asdf", disable_inst_caching=True) == self.kls(
+ "asdf", disable_inst_caching=True
+ )
- assert self.kls("asdf2", disable_inst_caching=True) != \
- self.kls("asdf", disable_inst_caching=True)
+ assert self.kls("asdf2", disable_inst_caching=True) != self.kls(
+ "asdf", disable_inst_caching=True
+ )
def test_it(self):
class maze:
@@ -159,13 +168,13 @@ class Test_chained_getter:
d = {}
m = maze(d)
f = self.kls
- assert f('foon')(m) == m
+ assert f("foon")(m) == m
d["foon"] = 1
- assert f('foon')(m) == 1
- assert f('dar.foon')(m) == 1
- assert f('.'.join(['blah']*10))(m) == m
+ assert f("foon")(m) == 1
+ assert f("dar.foon")(m) == 1
+ assert f(".".join(["blah"] * 10))(m) == m
with pytest.raises(AttributeError):
- f('foon.dar')(m)
+ f("foon.dar")(m)
class Test_jit_attr:
@@ -184,23 +193,28 @@ class Test_jit_attr:
def jit_attr_ext_method(self):
return partial(klass.jit_attr_ext_method, kls=self.kls)
- def mk_inst(self, attrname='_attr', method_lookup=False,
- use_cls_setattr=False, func=None,
- singleton=klass._uncached_singleton):
+ def mk_inst(
+ self,
+ attrname="_attr",
+ method_lookup=False,
+ use_cls_setattr=False,
+ func=None,
+ singleton=klass._uncached_singleton,
+ ):
f = func
if not func:
+
def f(self):
self._invokes.append(self)
return 54321
class cls:
-
def __init__(self):
sf = partial(object.__setattr__, self)
- sf('_sets', [])
- sf('_reflects', [])
- sf('_invokes', [])
+ sf("_sets", [])
+ sf("_reflects", [])
+ sf("_invokes", [])
attr = self.kls(f, attrname, singleton, use_cls_setattr)
@@ -219,13 +233,22 @@ class Test_jit_attr:
sets = [instance] * sets
reflects = [instance] * reflects
invokes = [instance] * invokes
- msg = ("checking %s: got(%r), expected(%r); state was sets=%r, "
- "reflects=%r, invokes=%r" % (
- "%s", "%s", "%s", instance._sets, instance._reflects,
- instance._invokes))
+ msg = (
+ "checking %s: got(%r), expected(%r); state was sets=%r, "
+ "reflects=%r, invokes=%r"
+ % ("%s", "%s", "%s", instance._sets, instance._reflects, instance._invokes)
+ )
assert instance._sets == sets, msg % ("sets", instance._sets, sets)
- assert instance._reflects == reflects, msg % ("reflects", instance._reflects, reflects)
- assert instance._invokes == invokes, msg % ("invokes", instance._invokes, invokes)
+ assert instance._reflects == reflects, msg % (
+ "reflects",
+ instance._reflects,
+ reflects,
+ )
+ assert instance._invokes == invokes, msg % (
+ "invokes",
+ instance._invokes,
+ invokes,
+ )
def test_implementation(self):
obj = self.mk_inst()
@@ -298,7 +321,7 @@ class Test_jit_attr:
object.__setattr__(self, attr, value)
o = cls()
- assert not hasattr(o, 'invoked')
+ assert not hasattr(o, "invoked")
assert o.my_attr == now
assert o._blah2 == now
assert o.invoked
@@ -315,34 +338,34 @@ class Test_jit_attr:
return now2
def __setattr__(self, attr, value):
- if not getattr(self, '_setattr_allowed', False):
+ if not getattr(self, "_setattr_allowed", False):
raise TypeError("setattr isn't allowed for %s" % attr)
object.__setattr__(self, attr, value)
- base.attr = self.jit_attr_ext_method('f1', '_attr')
+ base.attr = self.jit_attr_ext_method("f1", "_attr")
o = base()
assert o.attr == now
assert o._attr == now
assert o.attr == now
- base.attr = self.jit_attr_ext_method('f1', '_attr', use_cls_setattr=True)
+ base.attr = self.jit_attr_ext_method("f1", "_attr", use_cls_setattr=True)
o = base()
with pytest.raises(TypeError):
- getattr(o, 'attr')
+ getattr(o, "attr")
base._setattr_allowed = True
assert o.attr == now
- base.attr = self.jit_attr_ext_method('f2', '_attr2')
+ base.attr = self.jit_attr_ext_method("f2", "_attr2")
o = base()
assert o.attr == now2
assert o._attr2 == now2
# finally, check that it's doing lookups rather then storing the func.
- base.attr = self.jit_attr_ext_method('func', '_attr2')
+ base.attr = self.jit_attr_ext_method("func", "_attr2")
o = base()
# no func...
with pytest.raises(AttributeError):
- getattr(o, 'attr')
+ getattr(o, "attr")
base.func = base.f1
assert o.attr == now
assert o._attr2 == now
@@ -354,7 +377,13 @@ class Test_jit_attr:
def test_check_singleton_is_compare(self):
def throw_assert(*args, **kwds):
- raise AssertionError("I shouldn't be invoked: %s, %s" % (args, kwds,))
+ raise AssertionError(
+ "I shouldn't be invoked: %s, %s"
+ % (
+ args,
+ kwds,
+ )
+ )
class puker:
__eq__ = throw_assert
@@ -369,11 +398,13 @@ class Test_jit_attr:
def test_cached_property(self):
l = []
+
class foo:
@klass.cached_property
def blah(self, l=l, i=iter(range(5))):
l.append(None)
return next(i)
+
f = foo()
assert f.blah == 0
assert len(l) == 1
@@ -413,15 +444,15 @@ class Test_aliased_attr:
o = cls()
with pytest.raises(AttributeError):
- getattr(o, 'attr')
+ getattr(o, "attr")
o.dar = "foon"
with pytest.raises(AttributeError):
- getattr(o, 'attr')
+ getattr(o, "attr")
o.dar = o
o.blah = "monkey"
- assert o.attr == 'monkey'
+ assert o.attr == "monkey"
# verify it'll cross properties...
class blah:
@@ -431,6 +462,7 @@ class Test_aliased_attr:
@property
def foon(self):
return blah()
+
alias = self.func("foon.target")
o = cls()
@@ -442,12 +474,15 @@ class Test_cached_hash:
def test_it(self):
now = int(time())
+
class cls:
invoked = []
+
@self.func
def __hash__(self):
self.invoked.append(self)
return now
+
o = cls()
assert hash(o) == now
assert o.invoked == [o]
@@ -462,7 +497,7 @@ class Test_reflective_hash:
def test_it(self):
class cls:
- __hash__ = self.func('_hash')
+ __hash__ = self.func("_hash")
obj = cls()
with pytest.raises(AttributeError):
@@ -477,7 +512,8 @@ class Test_reflective_hash:
hash(obj)
class cls2:
- __hash__ = self.func('_dar')
+ __hash__ = self.func("_dar")
+
obj = cls2()
with pytest.raises(AttributeError):
hash(obj)
@@ -486,7 +522,6 @@ class Test_reflective_hash:
class TestImmutableInstance:
-
def test_metaclass(self):
self.common_test(lambda x: x, metaclass=klass.immutable_instance)
@@ -506,7 +541,7 @@ class TestImmutableInstance:
with pytest.raises(AttributeError):
delattr(o, "dar")
- object.__setattr__(o, 'dar', 'foon')
+ object.__setattr__(o, "dar", "foon")
with pytest.raises(AttributeError):
delattr(o, "dar")
@@ -541,7 +576,6 @@ class TestAliasMethod:
class TestPatch:
-
def setup_method(self, method):
# cache original methods
self._math_ceil = math.ceil
@@ -556,7 +590,7 @@ class TestPatch:
n = 0.1
assert math.ceil(n) == 1
- @klass.patch('math.ceil')
+ @klass.patch("math.ceil")
def ceil(orig_ceil, n):
return math.floor(n)
@@ -567,8 +601,8 @@ class TestPatch:
assert math.ceil(n) == 2
assert math.floor(n) == 1
- @klass.patch('math.ceil')
- @klass.patch('math.floor')
+ @klass.patch("math.ceil")
+ @klass.patch("math.floor")
def zero(orig_func, n):
return 0
diff --git a/tests/test_mappings.py b/tests/test_mappings.py
index b1aef254..1ffe7801 100644
--- a/tests/test_mappings.py
+++ b/tests/test_mappings.py
@@ -10,7 +10,6 @@ def a_dozen():
class BasicDict(mappings.DictMixin):
-
def __init__(self, i=None, **kwargs):
self._d = {}
mappings.DictMixin.__init__(self, i, **kwargs)
@@ -20,7 +19,6 @@ class BasicDict(mappings.DictMixin):
class MutableDict(BasicDict):
-
def __setitem__(self, key, val):
self._d[key] = val
@@ -36,7 +34,6 @@ class ImmutableDict(BasicDict):
class TestDictMixin:
-
def test_immutability(self):
d = ImmutableDict()
pytest.raises(AttributeError, d.__setitem__, "spork", "foon")
@@ -59,12 +56,12 @@ class TestDictMixin:
pytest.raises(KeyError, d.pop, "spork")
assert d.pop("spork", "bat") == "bat"
assert d.pop("foo") == "bar"
- assert d.popitem(), ("baz" == "cat")
+ assert d.popitem(), "baz" == "cat"
pytest.raises(KeyError, d.popitem)
assert d.pop("nonexistent", None) == None
def test_init(self):
- d = MutableDict((('foo', 'bar'), ('spork', 'foon')), baz="cat")
+ d = MutableDict((("foo", "bar"), ("spork", "foon")), baz="cat")
assert d["foo"] == "bar"
assert d["baz"] == "cat"
d.clear()
@@ -73,19 +70,20 @@ class TestDictMixin:
def test_bool(self):
d = MutableDict()
assert not d
- d['x'] = 1
+ d["x"] = 1
assert d
- del d['x']
+ del d["x"]
assert not d
class RememberingNegateMixin:
-
def setup_method(self, method):
self.negate_calls = []
+
def negate(i):
self.negate_calls.append(i)
return -i
+
self.negate = negate
def teardown_method(self, method):
@@ -94,7 +92,6 @@ class RememberingNegateMixin:
class LazyValDictTestMixin:
-
def test_invalid_operations(self):
pytest.raises(AttributeError, operator.setitem, self.dict, 7, 7)
pytest.raises(AttributeError, operator.delitem, self.dict, 7)
@@ -118,6 +115,7 @@ class LazyValDictTestMixin:
# missing key
def get():
return self.dict[42]
+
pytest.raises(KeyError, get)
def test_caching(self):
@@ -129,7 +127,6 @@ class LazyValDictTestMixin:
class TestLazyValDictWithList(LazyValDictTestMixin, RememberingNegateMixin):
-
def setup_method(self, method):
super().setup_method(method)
self.dict = mappings.LazyValDict(list(range(12)), self.negate)
@@ -148,14 +145,12 @@ class TestLazyValDictWithList(LazyValDictTestMixin, RememberingNegateMixin):
class TestLazyValDictWithFunc(LazyValDictTestMixin, RememberingNegateMixin):
-
def setup_method(self, method):
super().setup_method(method)
self.dict = mappings.LazyValDict(a_dozen, self.negate)
class TestLazyValDict:
-
def test_invalid_init_args(self):
pytest.raises(TypeError, mappings.LazyValDict, [1], 42)
pytest.raises(TypeError, mappings.LazyValDict, 42, a_dozen)
@@ -164,36 +159,43 @@ class TestLazyValDict:
# TODO check for valid values for dict.new, since that seems to be
# part of the interface?
class TestProtectedDict:
-
def setup_method(self, method):
self.orig = {1: -1, 2: -2}
self.dict = mappings.ProtectedDict(self.orig)
def test_basic_operations(self):
assert self.dict[1] == -1
+
def get(i):
return self.dict[i]
+
pytest.raises(KeyError, get, 3)
assert sorted(self.dict.keys()) == [1, 2]
assert -1 not in self.dict
assert 2 in self.dict
+
def remove(i):
del self.dict[i]
+
pytest.raises(KeyError, remove, 50)
def test_basic_mutating(self):
# add something
self.dict[7] = -7
+
def check_after_adding():
assert self.dict[7] == -7
assert 7 in self.dict
assert sorted(self.dict.keys()) == [1, 2, 7]
+
check_after_adding()
# remove it again
del self.dict[7]
assert 7 not in self.dict
+
def get(i):
return self.dict[i]
+
pytest.raises(KeyError, get, 7)
assert sorted(self.dict.keys()) == [1, 2]
# add it back
@@ -214,7 +216,6 @@ class TestProtectedDict:
class TestImmutableDict:
-
def test_init_iterator(self):
d = mappings.ImmutableDict((x, x) for x in range(3))
assert dict(d) == {0: 0, 1: 1, 2: 2}
@@ -239,7 +240,7 @@ class TestImmutableDict:
def test_init_dictmixin(self):
d = MutableDict(baz="cat")
e = mappings.ImmutableDict(d)
- assert dict(d) == {'baz': 'cat'}
+ assert dict(d) == {"baz": "cat"}
def test_init_bad_data(self):
for data in (range(10), list(range(10)), [([], 1)]):
@@ -288,7 +289,6 @@ class TestImmutableDict:
class TestOrderedFrozenSet:
-
def test_magic_methods(self):
s = mappings.OrderedFrozenSet(range(9))
for x in range(9):
@@ -299,7 +299,7 @@ class TestOrderedFrozenSet:
for i in range(9):
assert s[i] == i
assert list(s[1:]) == list(range(1, 9))
- with pytest.raises(IndexError, match='index out of range'):
+ with pytest.raises(IndexError, match="index out of range"):
s[9]
assert s == set(range(9))
@@ -308,12 +308,12 @@ class TestOrderedFrozenSet:
assert hash(s)
def test_ordering(self):
- s = mappings.OrderedFrozenSet('set')
- assert 'set' == ''.join(s)
- assert 'tes' == ''.join(reversed(s))
- s = mappings.OrderedFrozenSet('setordered')
- assert 'setord' == ''.join(s)
- assert 'drotes' == ''.join(reversed(s))
+ s = mappings.OrderedFrozenSet("set")
+ assert "set" == "".join(s)
+ assert "tes" == "".join(reversed(s))
+ s = mappings.OrderedFrozenSet("setordered")
+ assert "setord" == "".join(s)
+ assert "drotes" == "".join(reversed(s))
def test_immmutability(self):
s = mappings.OrderedFrozenSet(range(9))
@@ -355,41 +355,40 @@ class TestOrderedFrozenSet:
class TestOrderedSet(TestOrderedFrozenSet):
-
def test_hash(self):
with pytest.raises(TypeError):
- assert hash(mappings.OrderedSet('set'))
+ assert hash(mappings.OrderedSet("set"))
def test_add(self):
s = mappings.OrderedSet()
- s.add('a')
- assert 'a' in s
+ s.add("a")
+ assert "a" in s
s.add(1)
assert 1 in s
- assert list(s) == ['a', 1]
+ assert list(s) == ["a", 1]
def test_discard(self):
s = mappings.OrderedSet()
- s.discard('a')
- s.add('a')
+ s.discard("a")
+ s.add("a")
assert s
- s.discard('a')
+ s.discard("a")
assert not s
def test_remove(self):
s = mappings.OrderedSet()
with pytest.raises(KeyError):
- s.remove('a')
- s.add('a')
- assert 'a' in s
- s.remove('a')
- assert 'a' not in s
+ s.remove("a")
+ s.add("a")
+ assert "a" in s
+ s.remove("a")
+ assert "a" not in s
def test_clear(self):
s = mappings.OrderedSet()
s.clear()
assert len(s) == 0
- s.add('a')
+ s.add("a")
assert len(s) == 1
s.clear()
assert len(s) == 0
@@ -425,8 +424,9 @@ class TestStackedDict:
assert x in std
def test_len(self):
- assert sum(map(len, (self.orig_dict, self.new_dict))) == \
- len(mappings.StackedDict(self.orig_dict, self.new_dict))
+ assert sum(map(len, (self.orig_dict, self.new_dict))) == len(
+ mappings.StackedDict(self.orig_dict, self.new_dict)
+ )
def test_setattr(self):
pytest.raises(TypeError, mappings.StackedDict().__setitem__, (1, 2))
@@ -447,24 +447,28 @@ class TestStackedDict:
assert len(s) == 0
def test_keys(self):
- assert sorted(mappings.StackedDict(self.orig_dict, self.new_dict)) == \
- sorted(list(self.orig_dict.keys()) + list(self.new_dict.keys()))
+ assert sorted(mappings.StackedDict(self.orig_dict, self.new_dict)) == sorted(
+ list(self.orig_dict.keys()) + list(self.new_dict.keys())
+ )
class TestIndeterminantDict:
-
def test_disabled_methods(self):
d = mappings.IndeterminantDict(lambda *a: None)
for x in (
- "clear",
- ("update", {}),
- ("setdefault", 1),
- "__iter__", "__len__", "__hash__",
- ("__delitem__", 1),
- ("__setitem__", 2),
- ("popitem", 2),
- "keys", "items", "values",
- ):
+ "clear",
+ ("update", {}),
+ ("setdefault", 1),
+ "__iter__",
+ "__len__",
+ "__hash__",
+ ("__delitem__", 1),
+ ("__setitem__", 2),
+ ("popitem", 2),
+ "keys",
+ "items",
+ "values",
+ ):
if isinstance(x, tuple):
pytest.raises(TypeError, getattr(d, x[0]), x[1])
else:
@@ -472,7 +476,8 @@ class TestIndeterminantDict:
def test_starter_dict(self):
d = mappings.IndeterminantDict(
- lambda key: False, starter_dict={}.fromkeys(range(100), True))
+ lambda key: False, starter_dict={}.fromkeys(range(100), True)
+ )
for x in range(100):
assert d[x] == True
for x in range(100, 110):
@@ -481,21 +486,24 @@ class TestIndeterminantDict:
def test_behaviour(self):
val = []
d = mappings.IndeterminantDict(
- lambda key: val.append(key), {}.fromkeys(range(10), True))
+ lambda key: val.append(key), {}.fromkeys(range(10), True)
+ )
assert d[0] == True
assert d[11] == None
assert val == [11]
+
def func(*a):
raise KeyError
+
with pytest.raises(KeyError):
mappings.IndeterminantDict(func).__getitem__(1)
-
def test_get(self):
def func(key):
if key == 2:
raise KeyError
return True
+
d = mappings.IndeterminantDict(func, {1: 1})
assert d.get(1, 1) == 1
assert d.get(1, 2) == 1
@@ -505,41 +513,42 @@ class TestIndeterminantDict:
class TestFoldingDict:
-
def test_preserve(self):
dct = mappings.PreservingFoldingDict(
- str.lower, list({'Foo': 'bar', 'fnz': 'donkey'}.items()))
- assert dct['fnz'] == 'donkey'
- assert dct['foo'] == 'bar'
- assert sorted(['bar' == 'donkey']), sorted(dct.values())
+ str.lower, list({"Foo": "bar", "fnz": "donkey"}.items())
+ )
+ assert dct["fnz"] == "donkey"
+ assert dct["foo"] == "bar"
+ assert sorted(["bar" == "donkey"]), sorted(dct.values())
assert dct.copy() == dct
- assert dct['foo'] == dct.get('Foo')
- assert 'foo' in dct
- keys = ['Foo', 'fnz']
+ assert dct["foo"] == dct.get("Foo")
+ assert "foo" in dct
+ keys = ["Foo", "fnz"]
keysList = list(dct)
for key in keys:
assert key in list(dct.keys())
assert key in keysList
assert (key, dct[key]) in list(dct.items())
assert len(keys) == len(dct)
- assert dct.pop('foo') == 'bar'
- assert 'foo' not in dct
- del dct['fnz']
- assert 'fnz' not in dct
- dct['Foo'] = 'bar'
+ assert dct.pop("foo") == "bar"
+ assert "foo" not in dct
+ del dct["fnz"]
+ assert "fnz" not in dct
+ dct["Foo"] = "bar"
dct.refold(lambda _: _)
- assert 'foo' not in dct
- assert 'Foo' in dct
- assert list(dct.items()) == [('Foo', 'bar')]
+ assert "foo" not in dct
+ assert "Foo" in dct
+ assert list(dct.items()) == [("Foo", "bar")]
dct.clear()
assert {} == dict(dct)
def test_no_preserve(self):
dct = mappings.NonPreservingFoldingDict(
- str.lower, list({'Foo': 'bar', 'fnz': 'monkey'}.items()))
- assert sorted(['bar', 'monkey']) == sorted(dct.values())
+ str.lower, list({"Foo": "bar", "fnz": "monkey"}.items())
+ )
+ assert sorted(["bar", "monkey"]) == sorted(dct.values())
assert dct.copy() == dct
- keys = ['foo', 'fnz']
+ keys = ["foo", "fnz"]
keysList = [key for key in dct]
for key in keys:
assert key in list(dct.keys())
@@ -547,8 +556,8 @@ class TestFoldingDict:
assert key in keysList
assert (key, dct[key]) in list(dct.items())
assert len(keys) == len(dct)
- assert dct.pop('foo') == 'bar'
- del dct['fnz']
+ assert dct.pop("foo") == "bar"
+ del dct["fnz"]
assert list(dct.keys()) == []
dct.clear()
assert {} == dict(dct)
@@ -580,20 +589,20 @@ class Test_attr_to_item_mapping:
if kls is None:
kls = self.kls
o = kls(f=2, g=3)
- assert ['f', 'g'] == sorted(o)
- self.assertBoth(o, 'g', 3)
+ assert ["f", "g"] == sorted(o)
+ self.assertBoth(o, "g", 3)
o.g = 4
- self.assertBoth(o, 'g', 4)
+ self.assertBoth(o, "g", 4)
del o.g
with pytest.raises(KeyError):
- operator.__getitem__(o, 'g')
+ operator.__getitem__(o, "g")
with pytest.raises(AttributeError):
- getattr(o, 'g')
- del o['f']
+ getattr(o, "g")
+ del o["f"]
with pytest.raises(KeyError):
- operator.__getitem__(o, 'f')
+ operator.__getitem__(o, "f")
with pytest.raises(AttributeError):
- getattr(o, 'f')
+ getattr(o, "f")
def test_inject(self):
class foon(dict):
@@ -611,30 +620,31 @@ class Test_ProxiedAttrs:
def __init__(self, **kwargs):
for attr, val in kwargs.items():
setattr(self, attr, val)
+
obj = foo()
d = self.kls(obj)
with pytest.raises(KeyError):
- operator.__getitem__(d, 'x')
+ operator.__getitem__(d, "x")
with pytest.raises(KeyError):
- operator.__delitem__(d, 'x')
- assert 'x' not in d
- d['x'] = 1
- assert d['x'] == 1
- assert 'x' in d
- assert ['x'] == list(x for x in d if not x.startswith("__"))
- del d['x']
- assert 'x' not in d
+ operator.__delitem__(d, "x")
+ assert "x" not in d
+ d["x"] = 1
+ assert d["x"] == 1
+ assert "x" in d
+ assert ["x"] == list(x for x in d if not x.startswith("__"))
+ del d["x"]
+ assert "x" not in d
with pytest.raises(KeyError):
- operator.__delitem__(d, 'x')
+ operator.__delitem__(d, "x")
with pytest.raises(KeyError):
- operator.__getitem__(d, 'x')
+ operator.__getitem__(d, "x")
# Finally, verify that immutable attribute errors are handled correctly.
d = self.kls(object())
with pytest.raises(KeyError):
- operator.__setitem__(d, 'x', 1)
+ operator.__setitem__(d, "x", 1)
with pytest.raises(KeyError):
- operator.__delitem__(d, 'x')
+ operator.__delitem__(d, "x")
class TestSlottedDict:
@@ -642,9 +652,9 @@ class TestSlottedDict:
kls = staticmethod(mappings.make_SlottedDict_kls)
def test_exceptions(self):
- d = self.kls(['spork'])()
+ d = self.kls(["spork"])()
for op in (operator.getitem, operator.delitem):
with pytest.raises(KeyError):
- op(d, 'spork')
+ op(d, "spork")
with pytest.raises(KeyError):
- op(d, 'foon')
+ op(d, "foon")
diff --git a/tests/test_modules.py b/tests/test_modules.py
index f4174979..da20f4f8 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -5,19 +5,18 @@ from snakeoil import modules
class TestModules:
-
@pytest.fixture(autouse=True)
def _setup(self, tmp_path):
# set up some test modules for our use
- packdir = tmp_path / 'mod_testpack'
+ packdir = tmp_path / "mod_testpack"
packdir.mkdir()
# create an empty file
- (packdir / '__init__.py').touch()
+ (packdir / "__init__.py").touch()
for directory in (tmp_path, packdir):
for i in range(3):
- (directory / f'mod_test{i}.py').write_text('def foo(): pass\n')
- (directory / 'mod_horked.py').write_text('1/0\n')
+ (directory / f"mod_test{i}.py").write_text("def foo(): pass\n")
+ (directory / "mod_horked.py").write_text("1/0\n")
# append them to path
sys.path.insert(0, str(tmp_path))
@@ -27,89 +26,93 @@ class TestModules:
sys.path.pop(0)
# make sure we don't keep the sys.modules entries around
for i in range(3):
- sys.modules.pop('mod_test%s' % i, None)
- sys.modules.pop('mod_testpack.mod_test%s' % i, None)
- sys.modules.pop('mod_testpack', None)
- sys.modules.pop('mod_horked', None)
- sys.modules.pop('mod_testpack.mod_horked', None)
+ sys.modules.pop("mod_test%s" % i, None)
+ sys.modules.pop("mod_testpack.mod_test%s" % i, None)
+ sys.modules.pop("mod_testpack", None)
+ sys.modules.pop("mod_horked", None)
+ sys.modules.pop("mod_testpack.mod_horked", None)
def test_load_module(self):
# import an already-imported module
- assert modules.load_module('snakeoil.modules') is modules
+ assert modules.load_module("snakeoil.modules") is modules
# and a system one, just for kicks
- assert modules.load_module('sys') is sys
+ assert modules.load_module("sys") is sys
# non-existing module from an existing package
with pytest.raises(modules.FailedImport):
- modules.load_module('snakeoil.__not_there')
+ modules.load_module("snakeoil.__not_there")
# (hopefully :) non-existing top-level module/package
with pytest.raises(modules.FailedImport):
- modules.load_module('__not_there')
+ modules.load_module("__not_there")
# "Unable to import"
# pylint: disable=F0401
# unimported toplevel module
- modtest1 = modules.load_module('mod_test1')
+ modtest1 = modules.load_module("mod_test1")
import mod_test1
+
assert mod_test1 is modtest1
# unimported in-package module
- packtest2 = modules.load_module('mod_testpack.mod_test2')
+ packtest2 = modules.load_module("mod_testpack.mod_test2")
from mod_testpack import mod_test2
+
assert mod_test2 is packtest2
def test_load_attribute(self):
# already imported
- assert modules.load_attribute('sys.path') is sys.path
+ assert modules.load_attribute("sys.path") is sys.path
# unimported
- myfoo = modules.load_attribute('mod_testpack.mod_test2.foo')
+ myfoo = modules.load_attribute("mod_testpack.mod_test2.foo")
# "Unable to import"
# pylint: disable=F0401
from mod_testpack.mod_test2 import foo
+
assert foo is myfoo
# nonexisting attribute
with pytest.raises(modules.FailedImport):
- modules.load_attribute('snakeoil.froznicator')
+ modules.load_attribute("snakeoil.froznicator")
# nonexisting top-level
with pytest.raises(modules.FailedImport):
- modules.load_attribute('spork_does_not_exist.foo')
+ modules.load_attribute("spork_does_not_exist.foo")
# not an attr
with pytest.raises(modules.FailedImport):
- modules.load_attribute('sys')
+ modules.load_attribute("sys")
# not imported yet
with pytest.raises(modules.FailedImport):
- modules.load_attribute('mod_testpack.mod_test3')
+ modules.load_attribute("mod_testpack.mod_test3")
def test_load_any(self):
# import an already-imported module
- assert modules.load_any('snakeoil.modules') is modules
+ assert modules.load_any("snakeoil.modules") is modules
# attribute of an already imported module
- assert modules.load_any('sys.path') is sys.path
+ assert modules.load_any("sys.path") is sys.path
# already imported toplevel.
- assert sys is modules.load_any('sys')
+ assert sys is modules.load_any("sys")
# unimported
- myfoo = modules.load_any('mod_testpack.mod_test2.foo')
+ myfoo = modules.load_any("mod_testpack.mod_test2.foo")
# "Unable to import"
# pylint: disable=F0401
from mod_testpack.mod_test2 import foo
+
assert foo is myfoo
# nonexisting attribute
with pytest.raises(modules.FailedImport):
- modules.load_any('snakeoil.froznicator')
+ modules.load_any("snakeoil.froznicator")
# nonexisting top-level
with pytest.raises(modules.FailedImport):
- modules.load_any('spork_does_not_exist.foo')
+ modules.load_any("spork_does_not_exist.foo")
with pytest.raises(modules.FailedImport):
- modules.load_any('spork_does_not_exist')
+ modules.load_any("spork_does_not_exist")
# not imported yet
with pytest.raises(modules.FailedImport):
- modules.load_any('mod_testpack.mod_test3')
+ modules.load_any("mod_testpack.mod_test3")
def test_broken_module(self):
for func in [modules.load_module, modules.load_any]:
with pytest.raises(modules.FailedImport):
- func('mod_testpack.mod_horked')
- assert 'mod_testpack.mod_horked' not in sys.modules
+ func("mod_testpack.mod_horked")
+ assert "mod_testpack.mod_horked" not in sys.modules
diff --git a/tests/test_obj.py b/tests/test_obj.py
index 83f9f776..78083d8b 100644
--- a/tests/test_obj.py
+++ b/tests/test_obj.py
@@ -7,7 +7,6 @@ make_DIkls = obj.DelayedInstantiation_kls
class TestDelayedInstantiation:
-
def test_simple(self):
t = tuple([1, 2, 3])
o = make_DI(tuple, lambda: t)
@@ -19,19 +18,34 @@ class TestDelayedInstantiation:
assert t >= o
def test_descriptor_awareness(self):
- def assertKls(cls, ignores=(),
- default_ignores=("__new__", "__init__", "__init_subclass__",
- "__getattribute__", "__class__",
- "__getnewargs__", "__getstate__",
- "__doc__", "__class_getitem__")):
- required = set(x for x in dir(cls)
- if x.startswith("__") and x.endswith("__"))
+ def assertKls(
+ cls,
+ ignores=(),
+ default_ignores=(
+ "__new__",
+ "__init__",
+ "__init_subclass__",
+ "__getattribute__",
+ "__class__",
+ "__getnewargs__",
+ "__getstate__",
+ "__doc__",
+ "__class_getitem__",
+ ),
+ ):
+ required = set(
+ x for x in dir(cls) if x.startswith("__") and x.endswith("__")
+ )
missing = required.difference(obj.kls_descriptors)
missing.difference_update(obj.base_kls_descriptors)
missing.difference_update(default_ignores)
missing.difference_update(ignores)
- assert not missing, ("object %r potentially has unsupported special "
- "attributes: %s" % (cls, ', '.join(missing)))
+ assert (
+ not missing
+ ), "object %r potentially has unsupported special " "attributes: %s" % (
+ cls,
+ ", ".join(missing),
+ )
assertKls(object)
assertKls(1)
@@ -43,25 +57,38 @@ class TestDelayedInstantiation:
def test_BaseDelayedObject(self):
# assert that all methods/descriptors of object
# are covered via the base.
- o = set(dir(object)).difference(f"__{x}__" for x in (
- "class", "getattribute", "new", "init", "init_subclass", "getstate", "doc"))
+ o = set(dir(object)).difference(
+ f"__{x}__"
+ for x in (
+ "class",
+ "getattribute",
+ "new",
+ "init",
+ "init_subclass",
+ "getstate",
+ "doc",
+ )
+ )
diff = o.difference(obj.base_kls_descriptors)
- assert not diff, ("base delayed instantiation class should cover all of object, but "
- "%r was spotted" % (",".join(sorted(diff)),))
+ assert not diff, (
+ "base delayed instantiation class should cover all of object, but "
+ "%r was spotted" % (",".join(sorted(diff)),)
+ )
assert obj.DelayedInstantiation_kls(int, "1") + 2 == 3
-
def test_klass_choice_optimization(self):
"""ensure that BaseDelayedObject is used whenever possible"""
# note object is an odd one- it actually has a __doc__, thus
# it must always be a custom
o = make_DI(object, object)
- assert object.__getattribute__(o, '__class__') is not obj.BaseDelayedObject
+ assert object.__getattribute__(o, "__class__") is not obj.BaseDelayedObject
+
class foon:
pass
+
o = make_DI(foon, foon)
- cls = object.__getattribute__(o, '__class__')
+ cls = object.__getattribute__(o, "__class__")
assert cls is obj.BaseDelayedObject
# now ensure we always get the same kls back for derivatives
@@ -70,39 +97,43 @@ class TestDelayedInstantiation:
return True
o = make_DI(foon, foon)
- cls = object.__getattribute__(o, '__class__')
+ cls = object.__getattribute__(o, "__class__")
assert cls is not obj.BaseDelayedObject
o = make_DI(foon, foon)
- cls2 = object.__getattribute__(o, '__class__')
+ cls2 = object.__getattribute__(o, "__class__")
assert cls is cls2
def test__class__(self):
l = []
+
def f():
l.append(False)
return True
+
o = make_DI(bool, f)
assert isinstance(o, bool)
assert not l, "accessing __class__ shouldn't trigger instantiation"
def test__doc__(self):
l = []
+
def f():
l.append(True)
return foon()
+
class foon:
__doc__ = "monkey"
o = make_DI(foon, f)
- assert o.__doc__ == 'monkey'
+ assert o.__doc__ == "monkey"
assert not l, (
"in accessing __doc__, the instance was generated- "
"this is a class level attribute, thus shouldn't "
- "trigger instantiation")
+ "trigger instantiation"
+ )
class TestPopattr:
-
class Object:
pass
@@ -113,21 +144,21 @@ class TestPopattr:
def test_no_attrs(self):
# object without any attrs
with pytest.raises(AttributeError):
- obj.popattr(object(), 'nonexistent')
+ obj.popattr(object(), "nonexistent")
def test_nonexistent_attr(self):
# object with attr trying to get nonexistent attr
with pytest.raises(AttributeError):
- obj.popattr(self.o, 'nonexistent')
+ obj.popattr(self.o, "nonexistent")
def test_fallback(self):
# object with attr trying to get nonexistent attr using fallback
- value = obj.popattr(self.o, 'nonexistent', 2)
+ value = obj.popattr(self.o, "nonexistent", 2)
assert value == 2
def test_removed_attr(self):
- value = obj.popattr(self.o, 'test')
+ value = obj.popattr(self.o, "test")
assert value == 1
# verify that attr was removed from the object
with pytest.raises(AttributeError):
- obj.popattr(self.o, 'test')
+ obj.popattr(self.o, "test")
diff --git a/tests/test_osutils.py b/tests/test_osutils.py
index 18092823..264d670d 100644
--- a/tests/test_osutils.py
+++ b/tests/test_osutils.py
@@ -16,44 +16,45 @@ from snakeoil.osutils.mount import MNT_DETACH, MS_BIND, mount, umount
class ReaddirCommon:
-
@pytest.fixture
def subdir(self, tmp_path):
- subdir = tmp_path / 'dir'
+ subdir = tmp_path / "dir"
subdir.mkdir()
- (tmp_path / 'file').touch()
- os.mkfifo((tmp_path / 'fifo'))
+ (tmp_path / "file").touch()
+ os.mkfifo((tmp_path / "fifo"))
return subdir
def _test_missing(self, tmp_path, funcs):
for func in funcs:
- pytest.raises(OSError, func, tmp_path / 'spork')
+ pytest.raises(OSError, func, tmp_path / "spork")
class TestNativeListDir(ReaddirCommon):
-
def test_listdir(self, tmp_path, subdir):
- assert set(native_readdir.listdir(tmp_path)) == {'dir', 'fifo', 'file'}
+ assert set(native_readdir.listdir(tmp_path)) == {"dir", "fifo", "file"}
assert native_readdir.listdir(subdir) == []
def test_listdir_dirs(self, tmp_path, subdir):
- assert native_readdir.listdir_dirs(tmp_path) == ['dir']
+ assert native_readdir.listdir_dirs(tmp_path) == ["dir"]
assert native_readdir.listdir_dirs(subdir) == []
def test_listdir_files(self, tmp_path, subdir):
- assert native_readdir.listdir_files(tmp_path) == ['file']
+ assert native_readdir.listdir_files(tmp_path) == ["file"]
assert native_readdir.listdir_dirs(subdir) == []
def test_missing(self, tmp_path, subdir):
- return self._test_missing(tmp_path, (
- native_readdir.listdir,
- native_readdir.listdir_dirs,
- native_readdir.listdir_files,
- ))
+ return self._test_missing(
+ tmp_path,
+ (
+ native_readdir.listdir,
+ native_readdir.listdir_dirs,
+ native_readdir.listdir_files,
+ ),
+ )
def test_dangling_sym(self, tmp_path, subdir):
(tmp_path / "monkeys").symlink_to("foon")
- assert native_readdir.listdir_files(tmp_path) == ['file']
+ assert native_readdir.listdir_files(tmp_path) == ["file"]
class TestNativeReaddir(ReaddirCommon):
@@ -78,36 +79,37 @@ class TestNativeReaddir(ReaddirCommon):
class TestEnsureDirs:
-
def check_dir(self, path, uid, gid, mode):
assert path.is_dir()
st = os.stat(path)
- assert stat.S_IMODE(st.st_mode) == mode, \
- '0%o != 0%o' % (stat.S_IMODE(st.st_mode), mode)
+ assert stat.S_IMODE(st.st_mode) == mode, "0%o != 0%o" % (
+ stat.S_IMODE(st.st_mode),
+ mode,
+ )
assert st.st_uid == uid
assert st.st_gid == gid
def test_ensure_dirs(self, tmp_path):
# default settings
- path = tmp_path / 'foo' / 'bar'
+ path = tmp_path / "foo" / "bar"
assert osutils.ensure_dirs(path)
self.check_dir(path, os.geteuid(), os.getegid(), 0o777)
def test_minimal_nonmodifying(self, tmp_path):
- path = tmp_path / 'foo' / 'bar'
+ path = tmp_path / "foo" / "bar"
assert osutils.ensure_dirs(path, mode=0o755)
os.chmod(path, 0o777)
assert osutils.ensure_dirs(path, mode=0o755, minimal=True)
self.check_dir(path, os.geteuid(), os.getegid(), 0o777)
def test_minimal_modifying(self, tmp_path):
- path = tmp_path / 'foo' / 'bar'
+ path = tmp_path / "foo" / "bar"
assert osutils.ensure_dirs(path, mode=0o750)
assert osutils.ensure_dirs(path, mode=0o005, minimal=True)
self.check_dir(path, os.geteuid(), os.getegid(), 0o755)
def test_create_unwritable_subdir(self, tmp_path):
- path = tmp_path / 'restricted' / 'restricted'
+ path = tmp_path / "restricted" / "restricted"
# create the subdirs without 020 first
assert osutils.ensure_dirs(path.parent)
assert osutils.ensure_dirs(path, mode=0o020)
@@ -118,38 +120,39 @@ class TestEnsureDirs:
def test_path_is_a_file(self, tmp_path):
# fail if passed a path to an existing file
- path = tmp_path / 'file'
+ path = tmp_path / "file"
touch(path)
assert path.is_file()
assert not osutils.ensure_dirs(path, mode=0o700)
def test_non_dir_in_path(self, tmp_path):
# fail if one of the parts of the path isn't a dir
- path = tmp_path / 'file' / 'dir'
- (tmp_path / 'file').touch()
+ path = tmp_path / "file" / "dir"
+ (tmp_path / "file").touch()
assert not osutils.ensure_dirs(path, mode=0o700)
def test_mkdir_failing(self, tmp_path):
# fail if os.mkdir fails
- with mock.patch('snakeoil.osutils.os.mkdir') as mkdir:
- mkdir.side_effect = OSError(30, 'Read-only file system')
- path = tmp_path / 'dir'
+ with mock.patch("snakeoil.osutils.os.mkdir") as mkdir:
+ mkdir.side_effect = OSError(30, "Read-only file system")
+ path = tmp_path / "dir"
assert not osutils.ensure_dirs(path, mode=0o700)
# force temp perms
assert not osutils.ensure_dirs(path, mode=0o400)
- mkdir.side_effect = OSError(17, 'File exists')
+ mkdir.side_effect = OSError(17, "File exists")
assert not osutils.ensure_dirs(path, mode=0o700)
def test_chmod_or_chown_failing(self, tmp_path):
# fail if chmod or chown fails
- path = tmp_path / 'dir'
+ path = tmp_path / "dir"
path.mkdir()
path.chmod(0o750)
- with mock.patch('snakeoil.osutils.os.chmod') as chmod, \
- mock.patch('snakeoil.osutils.os.chown') as chown:
- chmod.side_effect = OSError(5, 'Input/output error')
+ with mock.patch("snakeoil.osutils.os.chmod") as chmod, mock.patch(
+ "snakeoil.osutils.os.chown"
+ ) as chown:
+ chmod.side_effect = OSError(5, "Input/output error")
# chmod failure when file exists and trying to reset perms to match
# the specified mode
@@ -163,13 +166,13 @@ class TestEnsureDirs:
# chown failure when resetting perms on parents
chmod.side_effect = None
- chown.side_effect = OSError(5, 'Input/output error')
+ chown.side_effect = OSError(5, "Input/output error")
assert not osutils.ensure_dirs(path, uid=1000, gid=1000, mode=0o400)
def test_reset_sticky_parent_perms(self, tmp_path):
# make sure perms are reset after traversing over sticky parents
- sticky_parent = tmp_path / 'dir'
- path = sticky_parent / 'dir'
+ sticky_parent = tmp_path / "dir"
+ path = sticky_parent / "dir"
sticky_parent.mkdir()
sticky_parent.chmod(0o2755)
pre_sticky_parent = os.stat(sticky_parent)
@@ -178,7 +181,7 @@ class TestEnsureDirs:
assert pre_sticky_parent.st_mode == post_sticky_parent.st_mode
def test_mode(self, tmp_path):
- path = tmp_path / 'mode' / 'mode'
+ path = tmp_path / "mode" / "mode"
assert osutils.ensure_dirs(path, mode=0o700)
self.check_dir(path, os.geteuid(), os.getegid(), 0o700)
# unrestrict it
@@ -188,12 +191,12 @@ class TestEnsureDirs:
def test_gid(self, tmp_path):
# abuse the portage group as secondary group
try:
- portage_gid = grp.getgrnam('portage').gr_gid
+ portage_gid = grp.getgrnam("portage").gr_gid
except KeyError:
- pytest.skip('the portage group does not exist')
+ pytest.skip("the portage group does not exist")
if portage_gid not in os.getgroups():
- pytest.skip('you are not in the portage group')
- path = tmp_path / 'group' / 'group'
+ pytest.skip("you are not in the portage group")
+ path = tmp_path / "group" / "group"
assert osutils.ensure_dirs(path, gid=portage_gid)
self.check_dir(path, os.geteuid(), portage_gid, 0o777)
assert osutils.ensure_dirs(path)
@@ -203,12 +206,11 @@ class TestEnsureDirs:
class TestAbsSymlink:
-
def test_abssymlink(self, tmp_path):
- target = tmp_path / 'target'
- linkname = tmp_path / 'link'
+ target = tmp_path / "target"
+ linkname = tmp_path / "link"
target.mkdir()
- linkname.symlink_to('target')
+ linkname.symlink_to("target")
assert osutils.abssymlink(linkname) == str(target)
@@ -223,28 +225,30 @@ class Test_Native_NormPath:
got = f(src)
assert got == val, f"{src!r}: expected {val!r}, got {got!r}"
- check('/foo/', '/foo')
- check('//foo/', '/foo')
- check('//foo/.', '/foo')
- check('//..', '/')
- check('//..//foo', '/foo')
- check('/foo/..', '/')
- check('..//foo', '../foo')
- check('../foo/../', '..')
- check('../', '..')
- check('../foo/..', '..')
- check('../foo/../dar', '../dar')
- check('.//foo', 'foo')
- check('/foo/../../', '/')
- check('/foo/../../..', '/')
- check('/tmp/foo/../dar/', '/tmp/dar')
- check('/tmp/foo/../dar', '/tmp/dar')
+ check("/foo/", "/foo")
+ check("//foo/", "/foo")
+ check("//foo/.", "/foo")
+ check("//..", "/")
+ check("//..//foo", "/foo")
+ check("/foo/..", "/")
+ check("..//foo", "../foo")
+ check("../foo/../", "..")
+ check("../", "..")
+ check("../foo/..", "..")
+ check("../foo/../dar", "../dar")
+ check(".//foo", "foo")
+ check("/foo/../../", "/")
+ check("/foo/../../..", "/")
+ check("/tmp/foo/../dar/", "/tmp/dar")
+ check("/tmp/foo/../dar", "/tmp/dar")
# explicit unicode and bytes
- check('/tmṕ/föo//../dár', '/tmṕ/dár')
- check(b'/tm\xe1\xb9\x95/f\xc3\xb6o//../d\xc3\xa1r', b'/tm\xe1\xb9\x95/d\xc3\xa1r')
- check('/föó/..', '/')
- check(b'/f\xc3\xb6\xc3\xb3/..', b'/')
+ check("/tmṕ/föo//../dár", "/tmṕ/dár")
+ check(
+ b"/tm\xe1\xb9\x95/f\xc3\xb6o//../d\xc3\xa1r", b"/tm\xe1\xb9\x95/d\xc3\xa1r"
+ )
+ check("/föó/..", "/")
+ check(b"/f\xc3\xb6\xc3\xb3/..", b"/")
@pytest.mark.skipif(os.getuid() != 0, reason="these tests must be ran as root")
@@ -253,7 +257,7 @@ class TestAccess:
func = staticmethod(osutils.fallback_access)
def test_fallback(self, tmp_path):
- fp = tmp_path / 'file'
+ fp = tmp_path / "file"
# create the file
fp.touch()
fp.chmod(0o000)
@@ -270,9 +274,9 @@ class Test_unlink_if_exists:
def test_it(self, tmp_path):
f = self.func
- path = tmp_path / 'target'
+ path = tmp_path / "target"
f(path)
- path.write_text('')
+ path.write_text("")
f(path)
assert not path.exists()
# and once more for good measure...
@@ -280,18 +284,17 @@ class Test_unlink_if_exists:
class TestSupportedSystems:
-
def test_supported_system(self):
- @supported_systems('supported')
+ @supported_systems("supported")
def func():
return True
- with mock.patch('snakeoil.osutils.sys') as _sys:
- _sys.configure_mock(platform='supported')
+ with mock.patch("snakeoil.osutils.sys") as _sys:
+ _sys.configure_mock(platform="supported")
assert func()
def test_unsupported_system(self):
- @supported_systems('unsupported')
+ @supported_systems("unsupported")
def func():
return True
@@ -299,39 +302,39 @@ class TestSupportedSystems:
func()
# make sure we're iterating through the system params correctly
- with mock.patch('snakeoil.osutils.sys') as _sys:
- _sys.configure_mock(platform='u')
+ with mock.patch("snakeoil.osutils.sys") as _sys:
+ _sys.configure_mock(platform="u")
with pytest.raises(NotImplementedError):
func()
def test_multiple_systems(self):
- @supported_systems('darwin', 'linux')
+ @supported_systems("darwin", "linux")
def func():
return True
- with mock.patch('snakeoil.osutils.sys') as _sys:
- _sys.configure_mock(platform='nonexistent')
+ with mock.patch("snakeoil.osutils.sys") as _sys:
+ _sys.configure_mock(platform="nonexistent")
with pytest.raises(NotImplementedError):
func()
- for platform in ('linux2', 'darwin'):
+ for platform in ("linux2", "darwin"):
_sys.configure_mock(platform=platform)
assert func()
-@pytest.mark.skipif(not sys.platform.startswith('linux'),
- reason='supported on Linux only')
+@pytest.mark.skipif(
+ not sys.platform.startswith("linux"), reason="supported on Linux only"
+)
class TestMount:
-
@pytest.fixture
def source(self, tmp_path):
- source = tmp_path / 'source'
+ source = tmp_path / "source"
source.mkdir()
return source
@pytest.fixture
def target(self, tmp_path):
- target = tmp_path / 'target'
+ target = tmp_path / "target"
target.mkdir()
return target
@@ -340,21 +343,25 @@ class TestMount:
# byte strings; if they are unicode strings the arguments get mangled
# leading to errors when the syscall is run. This confirms mount() from
# snakeoil.osutils always converts the arguments into byte strings.
- for source, target, fstype in ((b'source', b'target', b'fstype'),
- ('source', 'target', 'fstype')):
- with mock.patch('snakeoil.osutils.mount.ctypes') as mock_ctypes:
+ for source, target, fstype in (
+ (b"source", b"target", b"fstype"),
+ ("source", "target", "fstype"),
+ ):
+ with mock.patch("snakeoil.osutils.mount.ctypes") as mock_ctypes:
with pytest.raises(OSError):
mount(str(source), str(target), fstype, MS_BIND)
- mount_call = next(x for x in mock_ctypes.mock_calls if x[0] == 'CDLL().mount')
+ mount_call = next(
+ x for x in mock_ctypes.mock_calls if x[0] == "CDLL().mount"
+ )
for arg in mount_call[1][0:3]:
assert isinstance(arg, bytes)
def test_missing_dirs(self):
with pytest.raises(OSError) as cm:
- mount('source', 'target', None, MS_BIND)
+ mount("source", "target", None, MS_BIND)
assert cm.value.errno in (errno.EPERM, errno.ENOENT)
- @pytest.mark.skipif(os.getuid() == 0, reason='this test must be run as non-root')
+ @pytest.mark.skipif(os.getuid() == 0, reason="this test must be run as non-root")
def test_no_perms(self, source, target):
with pytest.raises(OSError) as cm:
mount(str(source), str(target), None, MS_BIND)
@@ -363,11 +370,15 @@ class TestMount:
umount(str(target))
assert cm.value.errno in (errno.EPERM, errno.EINVAL)
- @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/mnt') and os.path.exists('/proc/self/ns/user')),
- reason='user and mount namespace support required')
+ @pytest.mark.skipif(
+ not (
+ os.path.exists("/proc/self/ns/mnt") and os.path.exists("/proc/self/ns/user")
+ ),
+ reason="user and mount namespace support required",
+ )
def test_bind_mount(self, source, target):
- src_file = source / 'file'
- bind_file = target / 'file'
+ src_file = source / "file"
+ bind_file = target / "file"
src_file.touch()
try:
@@ -378,15 +389,19 @@ class TestMount:
umount(str(target))
assert not bind_file.exists()
except PermissionError:
- pytest.skip('No permission to use user and mount namespace')
-
- @pytest.mark.skipif(not (os.path.exists('/proc/self/ns/mnt') and os.path.exists('/proc/self/ns/user')),
- reason='user and mount namespace support required')
+ pytest.skip("No permission to use user and mount namespace")
+
+ @pytest.mark.skipif(
+ not (
+ os.path.exists("/proc/self/ns/mnt") and os.path.exists("/proc/self/ns/user")
+ ),
+ reason="user and mount namespace support required",
+ )
def test_lazy_unmount(self, source, target):
- src_file = source / 'file'
- bind_file = target / 'file'
+ src_file = source / "file"
+ bind_file = target / "file"
src_file.touch()
- src_file.write_text('foo')
+ src_file.write_text("foo")
try:
with Namespace(user=True, mount=True):
@@ -403,14 +418,14 @@ class TestMount:
# confirm the file doesn't exist in the bind mount anymore
assert not bind_file.exists()
# but the file is still accessible to the process
- assert f.read() == 'foo'
+ assert f.read() == "foo"
# trying to reopen causes IOError
with pytest.raises(IOError) as cm:
f = bind_file.open()
assert cm.value.errno == errno.ENOENT
except PermissionError:
- pytest.skip('No permission to use user and mount namespace')
+ pytest.skip("No permission to use user and mount namespace")
class TestSizeofFmt:
diff --git a/tests/test_process.py b/tests/test_process.py
index bb45712e..488b7b03 100644
--- a/tests/test_process.py
+++ b/tests/test_process.py
@@ -30,8 +30,10 @@ class TestFindBinary:
process.find_binary(self.script)
def test_fallback(self):
- fallback = process.find_binary(self.script, fallback=os.path.join('bin', self.script))
- assert fallback == os.path.join('bin', self.script)
+ fallback = process.find_binary(
+ self.script, fallback=os.path.join("bin", self.script)
+ )
+ assert fallback == os.path.join("bin", self.script)
def test_not_executable(self, tmp_path):
fp = tmp_path / self.script
diff --git a/tests/test_process_spawn.py b/tests/test_process_spawn.py
index 8981c6e7..556b34cc 100644
--- a/tests/test_process_spawn.py
+++ b/tests/test_process_spawn.py
@@ -6,11 +6,11 @@ from snakeoil import process
from snakeoil.contexts import chdir
from snakeoil.process import spawn
-BASH_BINARY = process.find_binary("bash", fallback='')
+BASH_BINARY = process.find_binary("bash", fallback="")
-@pytest.mark.skipif(not BASH_BINARY, reason='missing bash binary')
-class TestSpawn:
+@pytest.mark.skipif(not BASH_BINARY, reason="missing bash binary")
+class TestSpawn:
@pytest.fixture(autouse=True)
def _setup(self, tmp_path):
orig_path = os.environ["PATH"]
@@ -37,21 +37,25 @@ class TestSpawn:
def test_get_output(self, tmp_path, dev_null):
filename = "spawn-getoutput.sh"
for r, s, text, args in (
- [0, ["dar\n"], "echo dar\n", {}],
- [0, ["dar"], "echo -n dar", {}],
- [1, ["blah\n", "dar\n"], "echo blah\necho dar\nexit 1", {}],
- [0, [], "echo dar 1>&2", {"fd_pipes": {1: 1, 2: dev_null}}]):
+ [0, ["dar\n"], "echo dar\n", {}],
+ [0, ["dar"], "echo -n dar", {}],
+ [1, ["blah\n", "dar\n"], "echo blah\necho dar\nexit 1", {}],
+ [0, [], "echo dar 1>&2", {"fd_pipes": {1: 1, 2: dev_null}}],
+ ):
fp = self.generate_script(tmp_path, filename, text)
- assert (r, s) == spawn.spawn_get_output(str(fp), spawn_type=spawn.spawn_bash, **args)
+ assert (r, s) == spawn.spawn_get_output(
+ str(fp), spawn_type=spawn.spawn_bash, **args
+ )
os.unlink(fp)
@pytest.mark.skipif(not spawn.is_sandbox_capable(), reason="missing sandbox binary")
def test_sandbox(self, tmp_path):
- fp = self.generate_script(
- tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD")
+ fp = self.generate_script(tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD")
ret = spawn.spawn_get_output(str(fp), spawn_type=spawn.spawn_sandbox)
assert ret[1], "no output; exit code was %s; script location %s" % (ret[0], fp)
- assert "libsandbox.so" in [os.path.basename(x.strip()) for x in ret[1][0].split()]
+ assert "libsandbox.so" in [
+ os.path.basename(x.strip()) for x in ret[1][0].split()
+ ]
os.unlink(fp)
@pytest.mark.skipif(not spawn.is_sandbox_capable(), reason="missing sandbox binary")
@@ -60,15 +64,17 @@ class TestSpawn:
this verifies our fix works.
"""
- fp = self.generate_script(
- tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD")
+ fp = self.generate_script(tmp_path, "spawn-sandbox.sh", "echo $LD_PRELOAD")
dpath = tmp_path / "dar"
dpath.mkdir()
with chdir(dpath):
dpath.rmdir()
- assert "libsandbox.so" in \
- [os.path.basename(x.strip()) for x in spawn.spawn_get_output(
- str(fp), spawn_type=spawn.spawn_sandbox, cwd='/')[1][0].split()]
+ assert "libsandbox.so" in [
+ os.path.basename(x.strip())
+ for x in spawn.spawn_get_output(
+ str(fp), spawn_type=spawn.spawn_sandbox, cwd="/"
+ )[1][0].split()
+ ]
fp.unlink()
def test_process_exit_code(self):
@@ -98,13 +104,12 @@ class TestSpawn:
def test_spawn_bash(self, capfd):
# bash builtin for true without exec'ing true (eg, no path lookup)
- assert 0 == spawn.spawn_bash('echo bash')
+ assert 0 == spawn.spawn_bash("echo bash")
out, _err = capfd.readouterr()
- assert out.strip() == 'bash'
+ assert out.strip() == "bash"
def test_umask(self, tmp_path):
- fp = self.generate_script(
- tmp_path, "spawn-umask.sh", f"#!{BASH_BINARY}\numask")
+ fp = self.generate_script(tmp_path, "spawn-umask.sh", f"#!{BASH_BINARY}\numask")
try:
old_umask = os.umask(0)
if old_umask == 0:
@@ -113,7 +118,8 @@ class TestSpawn:
os.umask(desired)
else:
desired = 0
- assert str(desired).lstrip("0") == \
- spawn.spawn_get_output(str(fp))[1][0].strip().lstrip("0")
+ assert str(desired).lstrip("0") == spawn.spawn_get_output(str(fp))[1][
+ 0
+ ].strip().lstrip("0")
finally:
os.umask(old_umask)
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index edbaa5a0..0d8c5a62 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -8,13 +8,11 @@ from snakeoil.sequences import split_elements, split_negations
class UnhashableComplex(complex):
-
def __hash__(self):
raise TypeError
class TestStableUnique:
-
def common_check(self, func):
# silly
assert func(()) == []
@@ -23,9 +21,10 @@ class TestStableUnique:
# neither
def test_stable_unique(self, func=sequences.stable_unique):
- assert list(set([1, 2, 3])) == [1, 2, 3], \
- "this test is reliant on the interpreter hasing 1,2,3 into a specific ordering- " \
+ assert list(set([1, 2, 3])) == [1, 2, 3], (
+ "this test is reliant on the interpreter hasing 1,2,3 into a specific ordering- "
"for whatever reason, ordering differs, thus this test can't verify it"
+ )
assert func([3, 2, 1]) == [3, 2, 1]
def test_iter_stable_unique(self):
@@ -43,20 +42,19 @@ class TestStableUnique:
uc = UnhashableComplex
res = sequences.unstable_unique([uc(1, 0), uc(0, 1), uc(1, 0)])
# sortable
- assert sorted(sequences.unstable_unique(
- [[1, 2], [1, 3], [1, 2], [1, 3]])) == [[1, 2], [1, 3]]
+ assert sorted(sequences.unstable_unique([[1, 2], [1, 3], [1, 2], [1, 3]])) == [
+ [1, 2],
+ [1, 3],
+ ]
assert res == [uc(1, 0), uc(0, 1)] or res == [uc(0, 1), uc(1, 0)]
assert sorted(sequences.unstable_unique(self._generator())) == sorted(range(6))
class TestChainedLists:
-
@staticmethod
def gen_cl():
return sequences.ChainedLists(
- list(range(3)),
- list(range(3, 6)),
- list(range(6, 100))
+ list(range(3)), list(range(3, 6)), list(range(6, 100))
)
def test_contains(self):
@@ -72,7 +70,7 @@ class TestChainedLists:
def test_str(self):
l = sequences.ChainedLists(list(range(3)), list(range(3, 5)))
- assert str(l) == '[ [0, 1, 2], [3, 4] ]'
+ assert str(l) == "[ [0, 1, 2], [3, 4] ]"
def test_getitem(self):
cl = self.gen_cl()
@@ -108,15 +106,18 @@ class Test_iflatten_instance:
def test_it(self):
o = OrderedDict((k, None) for k in range(10))
for l, correct, skip in (
- (["asdf", ["asdf", "asdf"], 1, None],
- ["asdf", "asdf", "asdf", 1, None], str),
- ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)),
- ([o, 1, "fds"], list(range(10)) + [1, "fds"], str),
- ("fds", ["fds"], str),
- ("fds", ["f", "d", "s"], int),
- ('', [''], str),
- (1, [1], int),
- ):
+ (
+ ["asdf", ["asdf", "asdf"], 1, None],
+ ["asdf", "asdf", "asdf", 1, None],
+ str,
+ ),
+ ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)),
+ ([o, 1, "fds"], list(range(10)) + [1, "fds"], str),
+ ("fds", ["fds"], str),
+ ("fds", ["f", "d", "s"], int),
+ ("", [""], str),
+ (1, [1], int),
+ ):
iterator = self.func(l, skip)
assert list(iterator) == correct
assert list(iterator) == []
@@ -126,6 +127,7 @@ class Test_iflatten_instance:
# have to iterate.
def fail():
return list(self.func(None))
+
with pytest.raises(TypeError):
fail()
@@ -148,13 +150,16 @@ class Test_iflatten_func:
def test_it(self):
o = OrderedDict((k, None) for k in range(10))
for l, correct, skip in (
- (["asdf", ["asdf", "asdf"], 1, None],
- ["asdf", "asdf", "asdf", 1, None], str),
- ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)),
- ([o, 1, "fds"], list(range(10)) + [1, "fds"], str),
- ("fds", ["fds"], str),
- (1, [1], int),
- ):
+ (
+ ["asdf", ["asdf", "asdf"], 1, None],
+ ["asdf", "asdf", "asdf", 1, None],
+ str,
+ ),
+ ([o, 1, "fds"], [o, 1, "fds"], (str, OrderedDict)),
+ ([o, 1, "fds"], list(range(10)) + [1, "fds"], str),
+ ("fds", ["fds"], str),
+ (1, [1], int),
+ ):
iterator = self.func(l, lambda x: isinstance(x, skip))
assert list(iterator) == correct
assert list(iterator) == []
@@ -164,6 +169,7 @@ class Test_iflatten_func:
# have to iterate.
def fail():
return list(self.func(None, lambda x: False))
+
with pytest.raises(TypeError):
fail()
@@ -189,25 +195,24 @@ class Test_predicate_split:
assert true_l == list(range(0, 100, 2))
def test_key(self):
- false_l, true_l = self.kls(lambda x: x % 2 == 0,
- ([0, x] for x in range(100)),
- key=itemgetter(1))
+ false_l, true_l = self.kls(
+ lambda x: x % 2 == 0, ([0, x] for x in range(100)), key=itemgetter(1)
+ )
assert false_l == [[0, x] for x in range(1, 100, 2)]
assert true_l == [[0, x] for x in range(0, 100, 2)]
class TestSplitNegations:
-
def test_empty(self):
# empty input
- seq = ''
+ seq = ""
assert split_negations(seq) == ((), ())
def test_bad_value(self):
# no-value negation should raise a ValueError
bad_values = (
- '-',
- 'a b c - d f e',
+ "-",
+ "a b c - d f e",
)
for s in bad_values:
@@ -216,7 +221,7 @@ class TestSplitNegations:
def test_negs(self):
# all negs
- seq = ('-' + str(x) for x in range(100))
+ seq = ("-" + str(x) for x in range(100))
assert split_negations(seq) == (tuple(map(str, range(100))), ())
def test_pos(self):
@@ -226,31 +231,33 @@ class TestSplitNegations:
def test_neg_pos(self):
# both
- seq = (('-' + str(x), str(x)) for x in range(100))
+ seq = (("-" + str(x), str(x)) for x in range(100))
seq = chain.from_iterable(seq)
- assert split_negations(seq) == (tuple(map(str, range(100))), tuple(map(str, range(100))))
+ assert split_negations(seq) == (
+ tuple(map(str, range(100))),
+ tuple(map(str, range(100))),
+ )
def test_converter(self):
# converter method
- seq = (('-' + str(x), str(x)) for x in range(100))
+ seq = (("-" + str(x), str(x)) for x in range(100))
seq = chain.from_iterable(seq)
assert split_negations(seq, int) == (tuple(range(100)), tuple(range(100)))
class TestSplitElements:
-
def test_empty(self):
# empty input
- seq = ''
+ seq = ""
assert split_elements(seq) == ((), (), ())
def test_bad_value(self):
# no-value neg/pos should raise ValueErrors
bad_values = (
- '-',
- '+',
- 'a b c - d f e',
- 'a b c + d f e',
+ "-",
+ "+",
+ "a b c - d f e",
+ "a b c + d f e",
)
for s in bad_values:
@@ -259,7 +266,7 @@ class TestSplitElements:
def test_negs(self):
# all negs
- seq = ('-' + str(x) for x in range(100))
+ seq = ("-" + str(x) for x in range(100))
assert split_elements(seq) == (tuple(map(str, range(100))), (), ())
def test_neutral(self):
@@ -269,12 +276,12 @@ class TestSplitElements:
def test_pos(self):
# all pos
- seq = ('+' + str(x) for x in range(100))
+ seq = ("+" + str(x) for x in range(100))
assert split_elements(seq) == ((), (), tuple(map(str, range(100))))
def test_neg_pos(self):
# both negative and positive values
- seq = (('-' + str(x), '+' + str(x)) for x in range(100))
+ seq = (("-" + str(x), "+" + str(x)) for x in range(100))
seq = chain.from_iterable(seq)
assert split_elements(seq) == (
tuple(map(str, range(100))),
@@ -284,7 +291,7 @@ class TestSplitElements:
def test_neg_neu_pos(self):
# all three value types
- seq = (('-' + str(x), str(x), '+' + str(x)) for x in range(100))
+ seq = (("-" + str(x), str(x), "+" + str(x)) for x in range(100))
seq = chain.from_iterable(seq)
assert split_elements(seq) == (
tuple(map(str, range(100))),
@@ -294,7 +301,10 @@ class TestSplitElements:
def test_converter(self):
# converter method
- seq = (('-' + str(x), str(x), '+' + str(x)) for x in range(100))
+ seq = (("-" + str(x), str(x), "+" + str(x)) for x in range(100))
seq = chain.from_iterable(seq)
assert split_elements(seq, int) == (
- tuple(range(100)), tuple(range(100)), tuple(range(100)))
+ tuple(range(100)),
+ tuple(range(100)),
+ tuple(range(100)),
+ )
diff --git a/tests/test_stringio.py b/tests/test_stringio.py
index 4fb7c78b..1e6d1e5d 100644
--- a/tests/test_stringio.py
+++ b/tests/test_stringio.py
@@ -34,6 +34,7 @@ class readonly_mixin:
class Test_text_readonly(readonly_mixin):
kls = stringio.text_readonly
-class Test_bytes_readonly(readonly_mixin ):
+
+class Test_bytes_readonly(readonly_mixin):
kls = stringio.bytes_readonly
- encoding = 'utf8'
+ encoding = "utf8"
diff --git a/tests/test_strings.py b/tests/test_strings.py
index b55c2306..19305708 100644
--- a/tests/test_strings.py
+++ b/tests/test_strings.py
@@ -3,38 +3,36 @@ from snakeoil.strings import doc_dedent, pluralism
class TestPluralism:
-
def test_none(self):
# default
- assert pluralism([]) == 's'
+ assert pluralism([]) == "s"
# different suffix for nonexistence
- assert pluralism([], none='') == ''
+ assert pluralism([], none="") == ""
def test_singular(self):
# default
- assert pluralism([1]) == ''
+ assert pluralism([1]) == ""
# different suffix for singular existence
- assert pluralism([1], singular='o') == 'o'
+ assert pluralism([1], singular="o") == "o"
def test_plural(self):
# default
- assert pluralism([1, 2]) == 's'
+ assert pluralism([1, 2]) == "s"
# different suffix for plural existence
- assert pluralism([1, 2], plural='ies') == 'ies'
+ assert pluralism([1, 2], plural="ies") == "ies"
def test_int(self):
- assert pluralism(0) == 's'
- assert pluralism(1) == ''
- assert pluralism(2) == 's'
+ assert pluralism(0) == "s"
+ assert pluralism(1) == ""
+ assert pluralism(2) == "s"
class TestDocDedent:
-
def test_empty(self):
- s = ''
+ s = ""
assert s == doc_dedent(s)
def test_non_string(self):
@@ -42,20 +40,20 @@ class TestDocDedent:
doc_dedent(None)
def test_line(self):
- s = 'line'
+ s = "line"
assert s == doc_dedent(s)
def test_indented_line(self):
- for indent in ('\t', ' '):
- s = f'{indent}line'
- assert 'line' == doc_dedent(s)
+ for indent in ("\t", " "):
+ s = f"{indent}line"
+ assert "line" == doc_dedent(s)
def test_docstring(self):
s = """Docstring to test.
foo bar
"""
- assert 'Docstring to test.\n\nfoo bar\n' == doc_dedent(s)
+ assert "Docstring to test.\n\nfoo bar\n" == doc_dedent(s)
def test_all_indented(self):
s = """\
@@ -63,4 +61,4 @@ class TestDocDedent:
foo bar
"""
- assert 'Docstring to test.\n\nfoo bar\n' == doc_dedent(s)
+ assert "Docstring to test.\n\nfoo bar\n" == doc_dedent(s)
diff --git a/tests/test_version.py b/tests/test_version.py
index 7dad73e4..09927542 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -7,7 +7,6 @@ from snakeoil import __version__, version
class TestVersion:
-
def setup_method(self, method):
# reset the cached version in the module
reload(version)
@@ -17,124 +16,142 @@ class TestVersion:
def test_get_version_unknown(self):
with pytest.raises(ValueError):
- version.get_version('snakeoilfoo', __file__)
+ version.get_version("snakeoilfoo", __file__)
def test_get_version_api(self):
- v = version.get_version('snakeoil', __file__, '9.9.9')
- assert v.startswith('snakeoil 9.9.9')
+ v = version.get_version("snakeoil", __file__, "9.9.9")
+ assert v.startswith("snakeoil 9.9.9")
def test_get_version_git_dev(self):
- with mock.patch('snakeoil.version.import_module') as import_module, \
- mock.patch('snakeoil.version.get_git_version') as get_git_version:
+ with mock.patch("snakeoil.version.import_module") as import_module, mock.patch(
+ "snakeoil.version.get_git_version"
+ ) as get_git_version:
import_module.side_effect = ImportError
verinfo = {
- 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64',
- 'date': 'Mon Sep 25 13:50:24 2017 -0400',
- 'tag': None
+ "rev": "1ff76b021d208f7df38ac524537b6419404f1c64",
+ "date": "Mon Sep 25 13:50:24 2017 -0400",
+ "tag": None,
}
get_git_version.return_value = verinfo
- result = version.get_version('snakeoil', __file__, __version__)
- assert result == f"snakeoil {__version__}-g{verinfo['rev'][:7]} -- {verinfo['date']}"
+ result = version.get_version("snakeoil", __file__, __version__)
+ assert (
+ result
+ == f"snakeoil {__version__}-g{verinfo['rev'][:7]} -- {verinfo['date']}"
+ )
def test_get_version_git_release(self):
verinfo = {
- 'rev': 'ab38751890efa8be96b7f95938d6b868b769bab6',
- 'date': 'Thu Sep 21 15:57:38 2017 -0400',
- 'tag': '2.3.4',
+ "rev": "ab38751890efa8be96b7f95938d6b868b769bab6",
+ "date": "Thu Sep 21 15:57:38 2017 -0400",
+ "tag": "2.3.4",
}
# fake snakeoil._verinfo module object
class Verinfo:
version_info = verinfo
- with mock.patch('snakeoil.version.import_module') as import_module:
+ with mock.patch("snakeoil.version.import_module") as import_module:
import_module.return_value = Verinfo()
- result = version.get_version('snakeoil', __file__, verinfo['tag'])
+ result = version.get_version("snakeoil", __file__, verinfo["tag"])
assert result == f"snakeoil {verinfo['tag']} -- released {verinfo['date']}"
def test_get_version_no_git_version(self):
- with mock.patch('snakeoil.version.import_module') as import_module, \
- mock.patch('snakeoil.version.get_git_version') as get_git_version:
+ with mock.patch("snakeoil.version.import_module") as import_module, mock.patch(
+ "snakeoil.version.get_git_version"
+ ) as get_git_version:
import_module.side_effect = ImportError
get_git_version.return_value = None
- result = version.get_version('snakeoil', 'nonexistent', __version__)
- assert result == f'snakeoil {__version__}'
+ result = version.get_version("snakeoil", "nonexistent", __version__)
+ assert result == f"snakeoil {__version__}"
def test_get_version_caching(self):
# retrieved version info is cached in a module attr
- v = version.get_version('snakeoil', __file__)
- assert v.startswith(f'snakeoil {__version__}')
+ v = version.get_version("snakeoil", __file__)
+ assert v.startswith(f"snakeoil {__version__}")
# re-running get_version returns the cached attr instead of reprocessing
- with mock.patch('snakeoil.version.import_module') as import_module:
- v = version.get_version('snakeoil', __file__)
+ with mock.patch("snakeoil.version.import_module") as import_module:
+ v = version.get_version("snakeoil", __file__)
assert not import_module.called
class TestGitVersion:
-
def test_get_git_version_not_available(self):
- with mock.patch('snakeoil.version._run_git') as run_git:
- run_git.side_effect = EnvironmentError(errno.ENOENT, 'git not found')
- assert version.get_git_version('nonexistent') is None
+ with mock.patch("snakeoil.version._run_git") as run_git:
+ run_git.side_effect = EnvironmentError(errno.ENOENT, "git not found")
+ assert version.get_git_version("nonexistent") is None
def test_get_git_version_error(self):
- with mock.patch('snakeoil.version._run_git') as run_git:
- run_git.return_value = (b'foo', 1)
- assert version.get_git_version('nonexistent') is None
+ with mock.patch("snakeoil.version._run_git") as run_git:
+ run_git.return_value = (b"foo", 1)
+ assert version.get_git_version("nonexistent") is None
def test_get_git_version_non_repo(self, tmpdir):
assert version.get_git_version(str(tmpdir)) is None
def test_get_git_version_exc(self):
with pytest.raises(OSError):
- with mock.patch('snakeoil.version._run_git') as run_git:
- run_git.side_effect = OSError(errno.EIO, 'Input/output error')
- version.get_git_version('nonexistent')
+ with mock.patch("snakeoil.version._run_git") as run_git:
+ run_git.side_effect = OSError(errno.EIO, "Input/output error")
+ version.get_git_version("nonexistent")
def test_get_git_version_good_dev(self):
- with mock.patch('snakeoil.version._run_git') as run_git:
+ with mock.patch("snakeoil.version._run_git") as run_git:
# dev version
run_git.return_value = (
- b'1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400', 0)
- result = version.get_git_version('nonexistent')
+ b"1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400",
+ 0,
+ )
+ result = version.get_git_version("nonexistent")
expected = {
- 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64',
- 'date': 'Mon Sep 25 13:50:24 2017 -0400',
- 'tag': None,
- 'commits': 2,
+ "rev": "1ff76b021d208f7df38ac524537b6419404f1c64",
+ "date": "Mon Sep 25 13:50:24 2017 -0400",
+ "tag": None,
+ "commits": 2,
}
assert result == expected
def test_get_git_version_good_tag(self):
- with mock.patch('snakeoil.version._run_git') as run_git, \
- mock.patch('snakeoil.version._get_git_tag') as get_git_tag:
+ with mock.patch("snakeoil.version._run_git") as run_git, mock.patch(
+ "snakeoil.version._get_git_tag"
+ ) as get_git_tag:
# tagged, release version
run_git.return_value = (
- b'1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400', 0)
- get_git_tag.return_value = '1.1.1'
- result = version.get_git_version('nonexistent')
+ b"1ff76b021d208f7df38ac524537b6419404f1c64\nMon Sep 25 13:50:24 2017 -0400",
+ 0,
+ )
+ get_git_tag.return_value = "1.1.1"
+ result = version.get_git_version("nonexistent")
expected = {
- 'rev': '1ff76b021d208f7df38ac524537b6419404f1c64',
- 'date': 'Mon Sep 25 13:50:24 2017 -0400',
- 'tag': '1.1.1',
- 'commits': 2,
+ "rev": "1ff76b021d208f7df38ac524537b6419404f1c64",
+ "date": "Mon Sep 25 13:50:24 2017 -0400",
+ "tag": "1.1.1",
+ "commits": 2,
}
assert result == expected
def test_get_git_tag_bad_output(self):
- with mock.patch('snakeoil.version._run_git') as run_git:
+ with mock.patch("snakeoil.version._run_git") as run_git:
# unknown git tag rev output
- run_git.return_value = (b'a', 1)
- assert version._get_git_tag('foo', 'bar') is None
- run_git.return_value = (b'a foo/v0.7.2', 0)
- assert version._get_git_tag('foo', 'bar') is None
+ run_git.return_value = (b"a", 1)
+ assert version._get_git_tag("foo", "bar") is None
+ run_git.return_value = (b"a foo/v0.7.2", 0)
+ assert version._get_git_tag("foo", "bar") is None
# expected output formats
- run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1^0', 0)
- assert version._get_git_tag('foo', 'bar') == '1.1.1'
- run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1', 0)
- assert version._get_git_tag('foo', 'bar') == '1.1.1'
- run_git.return_value = (b'ab38751890efa8be96b7f95938d6b868b769bab6 tags/1.1.1', 0)
- assert version._get_git_tag('foo', 'bar') == '1.1.1'
+ run_git.return_value = (
+ b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1^0",
+ 0,
+ )
+ assert version._get_git_tag("foo", "bar") == "1.1.1"
+ run_git.return_value = (
+ b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/v1.1.1",
+ 0,
+ )
+ assert version._get_git_tag("foo", "bar") == "1.1.1"
+ run_git.return_value = (
+ b"ab38751890efa8be96b7f95938d6b868b769bab6 tags/1.1.1",
+ 0,
+ )
+ assert version._get_git_tag("foo", "bar") == "1.1.1"