Source code for craft_platforms._distro

# This file is part of craft-platforms.
#
# Copyright 2024 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License version 3, as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranties of MERCHANTABILITY,
# SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Distribution related utilities."""

from __future__ import annotations

import contextlib
import dataclasses
import typing
from typing import List, Union

import distro
from typing_extensions import Self


@typing.runtime_checkable
class BaseName(typing.Protocol):
    """Structural type for any object usable as an OS base.

    This protocol is a backwards-compatibility shim for the naming used in
    craft-providers: any object exposing ``name`` and ``version`` matches it
    at runtime via ``isinstance``.
    """

    @property
    def name(self) -> str:
        """Name of the base; used as the distribution by this module's helpers."""

    @property
    def version(self) -> str:
        """Version of the base; used as the series by this module's helpers."""
def _get_series_tuple(version_str: str) -> tuple[Union[int, str], ...]: """Convert a version string into a version tuple.""" parts = typing.cast(List[Union[str, int]], version_str.split(".")) # Try converting each part to an integer, leaving as a string if not doable. for idx, part in enumerate(parts): with contextlib.suppress(ValueError): parts[idx] = int(part) return tuple(parts) def _get_distro(base: Union[DistroBase, BaseName] | tuple[str, str]) -> str: """Get the distribution of a base.""" if isinstance(base, DistroBase): return base.distribution if isinstance(base, BaseName): return base.name return base[0] def _get_series(base: Union[DistroBase, BaseName] | tuple[str, str]) -> str: """Get the version of a base.""" if isinstance(base, DistroBase): return base.series if isinstance(base, BaseName): return base.version return base[1]
@dataclasses.dataclass(repr=True)
class DistroBase:
    """A linux distribution base.

    A base pairs a ``distribution`` name (e.g. ``ubuntu``) with a ``series``
    (e.g. ``22.04`` or ``devel``). It can be compared with other
    ``DistroBase`` objects, ``BaseName`` objects and ``(distribution, series)``
    string tuples.

    Ordering is only defined within a single distribution. Series are compared
    as tuples of their dot-separated components, with numeric components
    converted to ints, and ``devel`` sorts after every other series.
    NOTE(review): comparing a numeric component against a non-numeric one
    (e.g. ``22.04`` vs ``jammy``) raises TypeError from the tuple comparison.
    """

    distribution: str
    series: str

    def _ensure_bases_comparable(
        self,
        other: Union[DistroBase, BaseName] | tuple[str, str],
    ) -> None:
        """Ensure that these bases are comparable, raising an exception if not.

        :param other: Another distribution base.
        :raises: ValueError if the distribution bases are not comparable.
        """
        other_distro = _get_distro(other)
        if self.distribution != other_distro:
            raise ValueError(
                f"Different distributions ({self.distribution} and {other_distro}) do not have comparable versions.",
            )

    @staticmethod
    def _series_sort_key(series: str) -> tuple[object, ...]:
        """Build an ordering key for a series, placing ``devel`` last.

        Regular series map to ``(0, version_tuple)`` and ``devel`` maps to
        ``(1,)``. The leading element therefore decides every comparison
        involving ``devel``, and version tuples are only compared with each
        other.
        """
        if series == "devel":
            return (1,)
        return (0, _get_series_tuple(series))

    def _ordering_keys(
        self,
        other: Union[DistroBase, BaseName] | tuple[str, str],
    ) -> tuple[tuple[object, ...], tuple[object, ...]]:
        """Return the ordering keys of ``self`` and ``other``.

        :param other: Another distribution base.
        :returns: A ``(self_key, other_key)`` pair for use with rich comparisons.
        :raises: ValueError if the bases belong to different distributions.
        """
        self._ensure_bases_comparable(other)
        return (
            self._series_sort_key(self.series),
            self._series_sort_key(_get_series(other)),
        )

    def __eq__(self, other: object, /) -> bool:
        """Compare two bases for equality.

        Bases are equal when their distributions match and their series agree
        on every dot-separated component they share. This means the series may
        be more specific on one side (``22`` equals ``22.04``).

        Objects that are not bases return ``NotImplemented``. This includes
        tuples that are not a ``(distribution, series)`` pair of strings.
        """
        if not isinstance(other, (DistroBase, BaseName)) and not (
            isinstance(other, tuple)
            and len(other) == 2
            and all(isinstance(part, str) for part in other)
        ):
            # Any other object would make the helpers below raise
            # IndexError/AttributeError instead of comparing unequal.
            return NotImplemented
        other_distro = _get_distro(other)
        other_series = _get_series(other)
        if self.distribution != other_distro:
            return False
        # The series is allowed to be more specific on one side.
        return all(
            this_part == other_part
            for this_part, other_part in zip(
                self.series.split("."), other_series.split(".")
            )
        )

    def __lt__(self, other: Union[Self, BaseName, tuple[str, str]]) -> bool:
        this_key, other_key = self._ordering_keys(other)
        return this_key < other_key

    def __le__(self, other: Union[Self, BaseName, tuple[str, str]]) -> bool:
        # NOTE(review): this compares version tuples directly, so "22.04" <= "22"
        # is False even though __eq__ treats the two as equal.
        this_key, other_key = self._ordering_keys(other)
        return this_key <= other_key

    def __gt__(self, other: Union[Self, BaseName, tuple[str, str]]) -> bool:
        this_key, other_key = self._ordering_keys(other)
        return this_key > other_key

    def __ge__(self, other: Union[Self, BaseName, tuple[str, str]]) -> bool:
        this_key, other_key = self._ordering_keys(other)
        return this_key >= other_key

    @classmethod
    def from_str(cls, base_str: str) -> Self:
        """Parse a distribution string to a DistroBase.

        :param base_str: A distribution string (e.g. "ubuntu@24.04")
        :returns: A DistroBase of this string.
        :raises: ValueError if the string isn't of the appropriate format,
            including when the distribution or series part is empty.
        """
        # "devel" is an exception and corresponds to `ubuntu@devel`
        if base_str == "devel":
            return cls("ubuntu", "devel")
        distribution, _, series = base_str.partition("@")
        # Exactly one "@" is required, with non-empty text on both sides.
        if base_str.count("@") != 1 or not distribution or not series:
            raise ValueError(
                f"Invalid base string {base_str!r}. Format should be '<distribution>@<series>'",
            )
        return cls(distribution, series)

    @classmethod
    def from_linux_distribution(cls, distribution: distro.LinuxDistribution) -> Self:
        """Convert a distro package's LinuxDistribution object to a DistroBase.

        :param distribution: A LinuxDistribution from the distro package.
        :returns: A DistroBase built from ``distribution.id()`` and
            ``distribution.version()``.
        """
        return cls(distribution=distribution.id(), series=distribution.version())

    def __str__(self) -> str:
        """Render the base as ``<distribution>@<series>``."""
        return f"{self.distribution}@{self.series}"
def is_ubuntu_like(distribution: Union[distro.LinuxDistribution, None] = None) -> bool:
    """Report whether a distribution is Ubuntu or declares itself Ubuntu-like.

    :param distribution: Linux distribution info object, or None to inspect
        the host system.
    :returns: True if the distribution id is ``ubuntu`` or ``ubuntu`` appears
        among the whitespace-separated ids returned by ``like()``.
    """
    info = distro.LinuxDistribution() if distribution is None else distribution
    # Short-circuits: like() is only consulted when the id is not "ubuntu".
    return info.id() == "ubuntu" or "ubuntu" in info.like().split()