From a1fc5287d501a1ecdbd259e5bbdd4f7d5d06dd13 Mon Sep 17 00:00:00 2001
From: Alexander Graul <agraul@suse.com>
Date: Fri, 28 Apr 2023 09:41:28 +0200
Subject: [PATCH] Make sure the file client is destroyed after use

Backport of https://github.com/saltstack/salt/pull/64113
---
 salt/client/ssh/wrapper/saltcheck.py          | 108 +++----
 salt/fileclient.py                            |  11 -
 salt/modules/dockermod.py                     |  17 +-
 salt/pillar/__init__.py                       |   6 +-
 salt/states/ansiblegate.py                    |  11 +-
 salt/utils/asynchronous.py                    |   2 +-
 salt/utils/jinja.py                           |  53 ++-
 salt/utils/mako.py                            |   7 +
 salt/utils/templates.py                       | 303 +++++++++---------
 .../integration/states/test_include.py        |  40 +++
 .../utils/jinja/test_salt_cache_loader.py     |  47 ++-
 11 files changed, 330 insertions(+), 275 deletions(-)
 create mode 100644 tests/pytests/integration/states/test_include.py

diff --git a/salt/client/ssh/wrapper/saltcheck.py b/salt/client/ssh/wrapper/saltcheck.py
index d47b5cf6883..b0b94593809 100644
--- a/salt/client/ssh/wrapper/saltcheck.py
+++ b/salt/client/ssh/wrapper/saltcheck.py
@@ -9,6 +9,7 @@ import tarfile
 import tempfile
 from contextlib import closing
 
+import salt.fileclient
 import salt.utils.files
 import salt.utils.json
 import salt.utils.url
@@ -28,65 +29,62 @@ def update_master_cache(states, saltenv="base"):
     # Setup for copying states to gendir
     gendir = tempfile.mkdtemp()
     trans_tar = salt.utils.files.mkstemp()
-    if "cp.fileclient_{}".format(id(__opts__)) not in __context__:
-        __context__[
-            "cp.fileclient_{}".format(id(__opts__))
-        ] = salt.fileclient.get_file_client(__opts__)
-
-    # generate cp.list_states output and save to gendir
-    cp_output = salt.utils.json.dumps(__salt__["cp.list_states"]())
-    cp_output_file = os.path.join(gendir, "cp_output.txt")
-    with salt.utils.files.fopen(cp_output_file, "w") as fp:
-        fp.write(cp_output)
-
-    # cp state directories to gendir
-    already_processed = []
-    sls_list = salt.utils.args.split_input(states)
-    for state_name in sls_list:
-        # generate low data for each state and save to gendir
-        state_low_file = os.path.join(gendir, state_name + ".low")
-        state_low_output = salt.utils.json.dumps(
-            __salt__["state.show_low_sls"](state_name)
-        )
-        with salt.utils.files.fopen(state_low_file, "w") as fp:
-            fp.write(state_low_output)
-
-        state_name = state_name.replace(".", os.sep)
-        if state_name in already_processed:
-            log.debug("Already cached state for %s", state_name)
-        else:
-            file_copy_file = os.path.join(gendir, state_name + ".copy")
-            log.debug("copying %s to %s", state_name, gendir)
-            qualified_name = salt.utils.url.create(state_name, saltenv)
-            # Duplicate cp.get_dir to gendir
-            copy_result = __context__["cp.fileclient_{}".format(id(__opts__))].get_dir(
-                qualified_name, gendir, saltenv
+    with salt.fileclient.get_file_client(__opts__) as cp_fileclient:
+
+        # generate cp.list_states output and save to gendir
+        cp_output = salt.utils.json.dumps(__salt__["cp.list_states"]())
+        cp_output_file = os.path.join(gendir, "cp_output.txt")
+        with salt.utils.files.fopen(cp_output_file, "w") as fp:
+            fp.write(cp_output)
+
+        # cp state directories to gendir
+        already_processed = []
+        sls_list = salt.utils.args.split_input(states)
+        for state_name in sls_list:
+            # generate low data for each state and save to gendir
+            state_low_file = os.path.join(gendir, state_name + ".low")
+            state_low_output = salt.utils.json.dumps(
+                __salt__["state.show_low_sls"](state_name)
             )
-            if copy_result:
-                copy_result = [dir.replace(gendir, state_cache) for dir in copy_result]
-                copy_result_output = salt.utils.json.dumps(copy_result)
-                with salt.utils.files.fopen(file_copy_file, "w") as fp:
-                    fp.write(copy_result_output)
-                already_processed.append(state_name)
+            with salt.utils.files.fopen(state_low_file, "w") as fp:
+                fp.write(state_low_output)
+
+            state_name = state_name.replace(".", os.sep)
+            if state_name in already_processed:
+                log.debug("Already cached state for %s", state_name)
             else:
-                # If files were not copied, assume state.file.sls was given and just copy state
-                state_name = os.path.dirname(state_name)
                 file_copy_file = os.path.join(gendir, state_name + ".copy")
-                if state_name in already_processed:
-                    log.debug("Already cached state for %s", state_name)
+                log.debug("copying %s to %s", state_name, gendir)
+                qualified_name = salt.utils.url.create(state_name, saltenv)
+                # Duplicate cp.get_dir to gendir
+                copy_result = cp_fileclient.get_dir(qualified_name, gendir, saltenv)
+                if copy_result:
+                    copy_result = [
+                        dir.replace(gendir, state_cache) for dir in copy_result
+                    ]
+                    copy_result_output = salt.utils.json.dumps(copy_result)
+                    with salt.utils.files.fopen(file_copy_file, "w") as fp:
+                        fp.write(copy_result_output)
+                    already_processed.append(state_name)
                 else:
-                    qualified_name = salt.utils.url.create(state_name, saltenv)
-                    copy_result = __context__[
-                        "cp.fileclient_{}".format(id(__opts__))
-                    ].get_dir(qualified_name, gendir, saltenv)
-                    if copy_result:
-                        copy_result = [
-                            dir.replace(gendir, state_cache) for dir in copy_result
-                        ]
-                        copy_result_output = salt.utils.json.dumps(copy_result)
-                        with salt.utils.files.fopen(file_copy_file, "w") as fp:
-                            fp.write(copy_result_output)
-                        already_processed.append(state_name)
+                    # If files were not copied, assume state.file.sls was given and just copy state
+                    state_name = os.path.dirname(state_name)
+                    file_copy_file = os.path.join(gendir, state_name + ".copy")
+                    if state_name in already_processed:
+                        log.debug("Already cached state for %s", state_name)
+                    else:
+                        qualified_name = salt.utils.url.create(state_name, saltenv)
+                        copy_result = cp_fileclient.get_dir(
+                            qualified_name, gendir, saltenv
+                        )
+                        if copy_result:
+                            copy_result = [
+                                dir.replace(gendir, state_cache) for dir in copy_result
+                            ]
+                            copy_result_output = salt.utils.json.dumps(copy_result)
+                            with salt.utils.files.fopen(file_copy_file, "w") as fp:
+                                fp.write(copy_result_output)
+                            already_processed.append(state_name)
 
     # turn gendir into tarball and remove gendir
     try:
diff --git a/salt/fileclient.py b/salt/fileclient.py
index fef5154a0be..f01a86dd0d4 100644
--- a/salt/fileclient.py
+++ b/salt/fileclient.py
@@ -849,7 +849,6 @@ class Client:
             kwargs.pop("env")
 
         kwargs["saltenv"] = saltenv
-        url_data = urllib.parse.urlparse(url)
         sfn = self.cache_file(url, saltenv, cachedir=cachedir)
         if not sfn or not os.path.exists(sfn):
             return ""
@@ -1165,13 +1164,8 @@ class RemoteClient(Client):
 
         if not salt.utils.platform.is_windows():
             hash_server, stat_server = self.hash_and_stat_file(path, saltenv)
-            try:
-                mode_server = stat_server[0]
-            except (IndexError, TypeError):
-                mode_server = None
         else:
             hash_server = self.hash_file(path, saltenv)
-            mode_server = None
 
         # Check if file exists on server, before creating files and
         # directories
@@ -1214,13 +1208,8 @@ class RemoteClient(Client):
         if dest2check and os.path.isfile(dest2check):
             if not salt.utils.platform.is_windows():
                 hash_local, stat_local = self.hash_and_stat_file(dest2check, saltenv)
-                try:
-                    mode_local = stat_local[0]
-                except (IndexError, TypeError):
-                    mode_local = None
             else:
                 hash_local = self.hash_file(dest2check, saltenv)
-                mode_local = None
 
             if hash_local == hash_server:
                 return dest2check
diff --git a/salt/modules/dockermod.py b/salt/modules/dockermod.py
index f7344b66ac6..69b722f0c95 100644
--- a/salt/modules/dockermod.py
+++ b/salt/modules/dockermod.py
@@ -6667,14 +6667,6 @@ def script_retcode(
     )["retcode"]
 
 
-def _mk_fileclient():
-    """
-    Create a file client and add it to the context.
-    """
-    if "cp.fileclient" not in __context__:
-        __context__["cp.fileclient"] = salt.fileclient.get_file_client(__opts__)
-
-
 def _generate_tmp_path():
     return os.path.join("/tmp", "salt.docker.{}".format(uuid.uuid4().hex[:6]))
 
@@ -6688,11 +6680,10 @@ def _prepare_trans_tar(name, sls_opts, mods=None, pillar=None, extra_filerefs=""
     # reuse it from salt.ssh, however this function should
     # be somewhere else
     refs = salt.client.ssh.state.lowstate_file_refs(chunks, extra_filerefs)
-    _mk_fileclient()
-    trans_tar = salt.client.ssh.state.prep_trans_tar(
-        __context__["cp.fileclient"], chunks, refs, pillar, name
-    )
-    return trans_tar
+    with salt.fileclient.get_file_client(__opts__) as fileclient:
+        return salt.client.ssh.state.prep_trans_tar(
+            fileclient, chunks, refs, pillar, name
+        )
 
 
 def _compile_state(sls_opts, mods=None):
diff --git a/salt/pillar/__init__.py b/salt/pillar/__init__.py
index 0dfab4cc579..26312b3bd53 100644
--- a/salt/pillar/__init__.py
+++ b/salt/pillar/__init__.py
@@ -9,7 +9,6 @@ import logging
 import os
 import sys
 import traceback
-import uuid
 
 import salt.channel.client
 import salt.ext.tornado.gen
@@ -1351,6 +1350,11 @@ class Pillar:
         if hasattr(self, "_closing") and self._closing:
             return
         self._closing = True
+        if self.client:
+            try:
+                self.client.destroy()
+            except AttributeError:
+                pass
 
     # pylint: disable=W1701
     def __del__(self):
diff --git a/salt/states/ansiblegate.py b/salt/states/ansiblegate.py
index 7fd4deb6c2a..9abd418c42c 100644
--- a/salt/states/ansiblegate.py
+++ b/salt/states/ansiblegate.py
@@ -32,12 +32,10 @@ state:
           - state: installed
 
 """
-
 import logging
 import os
 import sys
 
-# Import salt modules
 import salt.fileclient
 import salt.utils.decorators.path
 from salt.utils.decorators import depends
@@ -108,13 +106,6 @@ def __virtual__():
     return __virtualname__
 
 
-def _client():
-    """
-    Get a fileclient
-    """
-    return salt.fileclient.get_file_client(__opts__)
-
-
 def _changes(plays):
     """
     Find changes in ansible return data
@@ -171,7 +162,7 @@ def playbooks(name, rundir=None, git_repo=None, git_kwargs=None, ansible_kwargs=
     }
     if git_repo:
         if not isinstance(rundir, str) or not os.path.isdir(rundir):
-            with _client() as client:
+            with salt.fileclient.get_file_client(__opts__) as client:
                 rundir = client._extrn_path(git_repo, "base")
             log.trace("rundir set to %s", rundir)
         if not isinstance(git_kwargs, dict):
diff --git a/salt/utils/asynchronous.py b/salt/utils/asynchronous.py
index 2a858feee98..0c645bbc3bb 100644
--- a/salt/utils/asynchronous.py
+++ b/salt/utils/asynchronous.py
@@ -131,7 +131,7 @@ class SyncWrapper:
             result = io_loop.run_sync(lambda: getattr(self.obj, key)(*args, **kwargs))
             results.append(True)
             results.append(result)
-        except Exception as exc:  # pylint: disable=broad-except
+        except Exception:  # pylint: disable=broad-except
             results.append(False)
             results.append(sys.exc_info())
 
diff --git a/salt/utils/jinja.py b/salt/utils/jinja.py
index fcc5aec497e..a6a8a279605 100644
--- a/salt/utils/jinja.py
+++ b/salt/utils/jinja.py
@@ -58,19 +58,6 @@ class SaltCacheLoader(BaseLoader):
     and only loaded once per loader instance.
     """
 
-    _cached_pillar_client = None
-    _cached_client = None
-
-    @classmethod
-    def shutdown(cls):
-        for attr in ("_cached_client", "_cached_pillar_client"):
-            client = getattr(cls, attr, None)
-            if client is not None:
-                # PillarClient and LocalClient objects do not have a destroy method
-                if hasattr(client, "destroy"):
-                    client.destroy()
-                setattr(cls, attr, None)
-
     def __init__(
         self,
         opts,
@@ -93,8 +80,7 @@ class SaltCacheLoader(BaseLoader):
         log.debug("Jinja search path: %s", self.searchpath)
         self.cached = []
         self._file_client = _file_client
-        # Instantiate the fileclient
-        self.file_client()
+        self._close_file_client = _file_client is None
 
     def file_client(self):
         """
@@ -108,18 +94,10 @@ class SaltCacheLoader(BaseLoader):
             or not hasattr(self._file_client, "opts")
             or self._file_client.opts["file_roots"] != self.opts["file_roots"]
         ):
-            attr = "_cached_pillar_client" if self.pillar_rend else "_cached_client"
-            cached_client = getattr(self, attr, None)
-            if (
-                cached_client is None
-                or not hasattr(cached_client, "opts")
-                or cached_client.opts["file_roots"] != self.opts["file_roots"]
-            ):
-                cached_client = salt.fileclient.get_file_client(
-                    self.opts, self.pillar_rend
-                )
-                setattr(SaltCacheLoader, attr, cached_client)
-            self._file_client = cached_client
+            self._file_client = salt.fileclient.get_file_client(
+                self.opts, self.pillar_rend
+            )
+            self._close_file_client = True
         return self._file_client
 
     def cache_file(self, template):
@@ -221,6 +199,27 @@ class SaltCacheLoader(BaseLoader):
         # there is no template file within searchpaths
         raise TemplateNotFound(template)
 
+    def destroy(self):
+        """Destroy the file client, but only if this loader created it."""
+        if self._close_file_client is False or self._file_client is None:
+            return
+        owned_client, self._file_client = self._file_client, None
+
+        try:
+            owned_client.destroy()
+        except AttributeError:
+            # PillarClient and LocalClient objects do not have a destroy method
+            pass
+
+    def __enter__(self):
+        """Make sure a file client is available and return the loader."""
+        self.file_client()
+        return self
+
+    def __exit__(self, *args):
+        """Call destroy() on leaving the ``with`` block."""
+        self.destroy()
+
 
 class PrintableDict(OrderedDict):
     """
diff --git a/salt/utils/mako.py b/salt/utils/mako.py
index 69618de9837..037d5d86deb 100644
--- a/salt/utils/mako.py
+++ b/salt/utils/mako.py
@@ -97,3 +97,10 @@ if HAS_MAKO:
                 self.cache[fpath] = self.file_client().get_file(
                     fpath, "", True, self.saltenv
                 )
+
+        def destroy(self):
+            if self._file_client:  # nothing to release if no client is set
+                try:
+                    self._file_client.destroy()
+                except AttributeError:
+                    pass  # not every file client type has destroy()
diff --git a/salt/utils/templates.py b/salt/utils/templates.py
index 4947b820a36..4a8adf2a14f 100644
--- a/salt/utils/templates.py
+++ b/salt/utils/templates.py
@@ -362,163 +362,169 @@ def render_jinja_tmpl(tmplstr, context, tmplpath=None):
     elif tmplstr.endswith("\n"):
         newline = "\n"
 
-    if not saltenv:
-        if tmplpath:
-            loader = jinja2.FileSystemLoader(os.path.dirname(tmplpath))
-    else:
-        loader = salt.utils.jinja.SaltCacheLoader(
-            opts,
-            saltenv,
-            pillar_rend=context.get("_pillar_rend", False),
-            _file_client=file_client,
-        )
+    try:
+        if not saltenv:
+            if tmplpath:
+                loader = jinja2.FileSystemLoader(os.path.dirname(tmplpath))
+        else:
+            loader = salt.utils.jinja.SaltCacheLoader(
+                opts,
+                saltenv,
+                pillar_rend=context.get("_pillar_rend", False),
+                _file_client=file_client,
+            )
 
-    env_args = {"extensions": [], "loader": loader}
-
-    if hasattr(jinja2.ext, "with_"):
-        env_args["extensions"].append("jinja2.ext.with_")
-    if hasattr(jinja2.ext, "do"):
-        env_args["extensions"].append("jinja2.ext.do")
-    if hasattr(jinja2.ext, "loopcontrols"):
-        env_args["extensions"].append("jinja2.ext.loopcontrols")
-    env_args["extensions"].append(salt.utils.jinja.SerializerExtension)
-
-    opt_jinja_env = opts.get("jinja_env", {})
-    opt_jinja_sls_env = opts.get("jinja_sls_env", {})
-
-    opt_jinja_env = opt_jinja_env if isinstance(opt_jinja_env, dict) else {}
-    opt_jinja_sls_env = opt_jinja_sls_env if isinstance(opt_jinja_sls_env, dict) else {}
-
-    # Pass through trim_blocks and lstrip_blocks Jinja parameters
-    # trim_blocks removes newlines around Jinja blocks
-    # lstrip_blocks strips tabs and spaces from the beginning of
-    # line to the start of a block.
-    if opts.get("jinja_trim_blocks", False):
-        log.debug("Jinja2 trim_blocks is enabled")
-        log.warning(
-            "jinja_trim_blocks is deprecated and will be removed in a future release,"
-            " please use jinja_env and/or jinja_sls_env instead"
-        )
-        opt_jinja_env["trim_blocks"] = True
-        opt_jinja_sls_env["trim_blocks"] = True
-    if opts.get("jinja_lstrip_blocks", False):
-        log.debug("Jinja2 lstrip_blocks is enabled")
-        log.warning(
-            "jinja_lstrip_blocks is deprecated and will be removed in a future release,"
-            " please use jinja_env and/or jinja_sls_env instead"
-        )
-        opt_jinja_env["lstrip_blocks"] = True
-        opt_jinja_sls_env["lstrip_blocks"] = True
-
-    def opt_jinja_env_helper(opts, optname):
-        for k, v in opts.items():
-            k = k.lower()
-            if hasattr(jinja2.defaults, k.upper()):
-                log.debug("Jinja2 environment %s was set to %s by %s", k, v, optname)
-                env_args[k] = v
-            else:
-                log.warning("Jinja2 environment %s is not recognized", k)
+        env_args = {"extensions": [], "loader": loader}
 
-    if "sls" in context and context["sls"] != "":
-        opt_jinja_env_helper(opt_jinja_sls_env, "jinja_sls_env")
-    else:
-        opt_jinja_env_helper(opt_jinja_env, "jinja_env")
+        if hasattr(jinja2.ext, "with_"):
+            env_args["extensions"].append("jinja2.ext.with_")
+        if hasattr(jinja2.ext, "do"):
+            env_args["extensions"].append("jinja2.ext.do")
+        if hasattr(jinja2.ext, "loopcontrols"):
+            env_args["extensions"].append("jinja2.ext.loopcontrols")
+        env_args["extensions"].append(salt.utils.jinja.SerializerExtension)
 
-    if opts.get("allow_undefined", False):
-        jinja_env = jinja2.sandbox.SandboxedEnvironment(**env_args)
-    else:
-        jinja_env = jinja2.sandbox.SandboxedEnvironment(
-            undefined=jinja2.StrictUndefined, **env_args
-        )
+        opt_jinja_env = opts.get("jinja_env", {})
+        opt_jinja_sls_env = opts.get("jinja_sls_env", {})
 
-    indent_filter = jinja_env.filters.get("indent")
-    jinja_env.tests.update(JinjaTest.salt_jinja_tests)
-    jinja_env.filters.update(JinjaFilter.salt_jinja_filters)
-    if salt.utils.jinja.JINJA_VERSION >= Version("2.11"):
-        # Use the existing indent filter on Jinja versions where it's not broken
-        jinja_env.filters["indent"] = indent_filter
-    jinja_env.globals.update(JinjaGlobal.salt_jinja_globals)
-
-    # globals
-    jinja_env.globals["odict"] = OrderedDict
-    jinja_env.globals["show_full_context"] = salt.utils.jinja.show_full_context
-
-    jinja_env.tests["list"] = salt.utils.data.is_list
-
-    decoded_context = {}
-    for key, value in context.items():
-        if not isinstance(value, str):
-            if isinstance(value, NamedLoaderContext):
-                decoded_context[key] = value.value()
-            else:
-                decoded_context[key] = value
-            continue
+        opt_jinja_env = opt_jinja_env if isinstance(opt_jinja_env, dict) else {}
+        opt_jinja_sls_env = (
+            opt_jinja_sls_env if isinstance(opt_jinja_sls_env, dict) else {}
+        )
 
-        try:
-            decoded_context[key] = salt.utils.stringutils.to_unicode(
-                value, encoding=SLS_ENCODING
+        # Pass through trim_blocks and lstrip_blocks Jinja parameters
+        # trim_blocks removes newlines around Jinja blocks
+        # lstrip_blocks strips tabs and spaces from the beginning of
+        # line to the start of a block.
+        if opts.get("jinja_trim_blocks", False):
+            log.debug("Jinja2 trim_blocks is enabled")
+            log.warning(
+                "jinja_trim_blocks is deprecated and will be removed in a future release,"
+                " please use jinja_env and/or jinja_sls_env instead"
             )
-        except UnicodeDecodeError as ex:
-            log.debug(
-                "Failed to decode using default encoding (%s), trying system encoding",
-                SLS_ENCODING,
+            opt_jinja_env["trim_blocks"] = True
+            opt_jinja_sls_env["trim_blocks"] = True
+        if opts.get("jinja_lstrip_blocks", False):
+            log.debug("Jinja2 lstrip_blocks is enabled")
+            log.warning(
+                "jinja_lstrip_blocks is deprecated and will be removed in a future release,"
+                " please use jinja_env and/or jinja_sls_env instead"
             )
-            decoded_context[key] = salt.utils.data.decode(value)
+            opt_jinja_env["lstrip_blocks"] = True
+            opt_jinja_sls_env["lstrip_blocks"] = True
+
+        def opt_jinja_env_helper(opts, optname):
+            for k, v in opts.items():
+                k = k.lower()
+                if hasattr(jinja2.defaults, k.upper()):
+                    log.debug(
+                        "Jinja2 environment %s was set to %s by %s", k, v, optname
+                    )
+                    env_args[k] = v
+                else:
+                    log.warning("Jinja2 environment %s is not recognized", k)
 
-    jinja_env.globals.update(decoded_context)
-    try:
-        template = jinja_env.from_string(tmplstr)
-        output = template.render(**decoded_context)
-    except jinja2.exceptions.UndefinedError as exc:
-        trace = traceback.extract_tb(sys.exc_info()[2])
-        line, out = _get_jinja_error(trace, context=decoded_context)
-        if not line:
-            tmplstr = ""
-        raise SaltRenderError("Jinja variable {}{}".format(exc, out), line, tmplstr)
-    except (
-        jinja2.exceptions.TemplateRuntimeError,
-        jinja2.exceptions.TemplateSyntaxError,
-        jinja2.exceptions.SecurityError,
-    ) as exc:
-        trace = traceback.extract_tb(sys.exc_info()[2])
-        line, out = _get_jinja_error(trace, context=decoded_context)
-        if not line:
-            tmplstr = ""
-        raise SaltRenderError(
-            "Jinja syntax error: {}{}".format(exc, out), line, tmplstr
-        )
-    except (SaltInvocationError, CommandExecutionError) as exc:
-        trace = traceback.extract_tb(sys.exc_info()[2])
-        line, out = _get_jinja_error(trace, context=decoded_context)
-        if not line:
-            tmplstr = ""
-        raise SaltRenderError(
-            "Problem running salt function in Jinja template: {}{}".format(exc, out),
-            line,
-            tmplstr,
-        )
-    except Exception as exc:  # pylint: disable=broad-except
-        tracestr = traceback.format_exc()
-        trace = traceback.extract_tb(sys.exc_info()[2])
-        line, out = _get_jinja_error(trace, context=decoded_context)
-        if not line:
-            tmplstr = ""
+        if "sls" in context and context["sls"] != "":
+            opt_jinja_env_helper(opt_jinja_sls_env, "jinja_sls_env")
         else:
-            tmplstr += "\n{}".format(tracestr)
-        log.debug("Jinja Error")
-        log.debug("Exception:", exc_info=True)
-        log.debug("Out: %s", out)
-        log.debug("Line: %s", line)
-        log.debug("TmplStr: %s", tmplstr)
-        log.debug("TraceStr: %s", tracestr)
+            opt_jinja_env_helper(opt_jinja_env, "jinja_env")
 
-        raise SaltRenderError(
-            "Jinja error: {}{}".format(exc, out), line, tmplstr, trace=tracestr
-        )
+        if opts.get("allow_undefined", False):
+            jinja_env = jinja2.sandbox.SandboxedEnvironment(**env_args)
+        else:
+            jinja_env = jinja2.sandbox.SandboxedEnvironment(
+                undefined=jinja2.StrictUndefined, **env_args
+            )
+
+        indent_filter = jinja_env.filters.get("indent")
+        jinja_env.tests.update(JinjaTest.salt_jinja_tests)
+        jinja_env.filters.update(JinjaFilter.salt_jinja_filters)
+        if salt.utils.jinja.JINJA_VERSION >= Version("2.11"):
+            # Use the existing indent filter on Jinja versions where it's not broken
+            jinja_env.filters["indent"] = indent_filter
+        jinja_env.globals.update(JinjaGlobal.salt_jinja_globals)
+
+        # globals
+        jinja_env.globals["odict"] = OrderedDict
+        jinja_env.globals["show_full_context"] = salt.utils.jinja.show_full_context
+
+        jinja_env.tests["list"] = salt.utils.data.is_list
+
+        decoded_context = {}
+        for key, value in context.items():
+            if not isinstance(value, str):
+                if isinstance(value, NamedLoaderContext):
+                    decoded_context[key] = value.value()
+                else:
+                    decoded_context[key] = value
+                continue
+
+            try:
+                decoded_context[key] = salt.utils.stringutils.to_unicode(
+                    value, encoding=SLS_ENCODING
+                )
+            except UnicodeDecodeError:
+                log.debug(
+                    "Failed to decode using default encoding (%s), trying system encoding",
+                    SLS_ENCODING,
+                )
+                decoded_context[key] = salt.utils.data.decode(value)
+
+        jinja_env.globals.update(decoded_context)
+        try:
+            template = jinja_env.from_string(tmplstr)
+            output = template.render(**decoded_context)
+        except jinja2.exceptions.UndefinedError as exc:
+            trace = traceback.extract_tb(sys.exc_info()[2])
+            line, out = _get_jinja_error(trace, context=decoded_context)
+            if not line:
+                tmplstr = ""
+            raise SaltRenderError("Jinja variable {}{}".format(exc, out), line, tmplstr)
+        except (
+            jinja2.exceptions.TemplateRuntimeError,
+            jinja2.exceptions.TemplateSyntaxError,
+            jinja2.exceptions.SecurityError,
+        ) as exc:
+            trace = traceback.extract_tb(sys.exc_info()[2])
+            line, out = _get_jinja_error(trace, context=decoded_context)
+            if not line:
+                tmplstr = ""
+            raise SaltRenderError(
+                "Jinja syntax error: {}{}".format(exc, out), line, tmplstr
+            )
+        except (SaltInvocationError, CommandExecutionError) as exc:
+            trace = traceback.extract_tb(sys.exc_info()[2])
+            line, out = _get_jinja_error(trace, context=decoded_context)
+            if not line:
+                tmplstr = ""
+            raise SaltRenderError(
+                "Problem running salt function in Jinja template: {}{}".format(
+                    exc, out
+                ),
+                line,
+                tmplstr,
+            )
+        except Exception as exc:  # pylint: disable=broad-except
+            tracestr = traceback.format_exc()
+            trace = traceback.extract_tb(sys.exc_info()[2])
+            line, out = _get_jinja_error(trace, context=decoded_context)
+            if not line:
+                tmplstr = ""
+            else:
+                tmplstr += "\n{}".format(tracestr)
+            log.debug("Jinja Error")
+            log.debug("Exception:", exc_info=True)
+            log.debug("Out: %s", out)
+            log.debug("Line: %s", line)
+            log.debug("TmplStr: %s", tmplstr)
+            log.debug("TraceStr: %s", tracestr)
+
+            raise SaltRenderError(
+                "Jinja error: {}{}".format(exc, out), line, tmplstr, trace=tracestr
+            )
     finally:
-        if loader and hasattr(loader, "_file_client"):
-            if hasattr(loader._file_client, "destroy"):
-                loader._file_client.destroy()
+        if loader and isinstance(loader, salt.utils.jinja.SaltCacheLoader):
+            loader.destroy()
 
     # Workaround a bug in Jinja that removes the final newline
     # (https://github.com/mitsuhiko/jinja2/issues/75)
@@ -569,9 +575,8 @@ def render_mako_tmpl(tmplstr, context, tmplpath=None):
     except Exception:  # pylint: disable=broad-except
         raise SaltRenderError(mako.exceptions.text_error_template().render())
     finally:
-        if lookup and hasattr(lookup, "_file_client"):
-            if hasattr(lookup._file_client, "destroy"):
-                lookup._file_client.destroy()
+        if lookup and isinstance(lookup, SaltMakoTemplateLookup):
+            lookup.destroy()
 
 
 def render_wempy_tmpl(tmplstr, context, tmplpath=None):
diff --git a/tests/pytests/integration/states/test_include.py b/tests/pytests/integration/states/test_include.py
new file mode 100644
index 00000000000..f814328c5e4
--- /dev/null
+++ b/tests/pytests/integration/states/test_include.py
@@ -0,0 +1,40 @@
+"""
+Integration tests for the jinja includes in states
+"""
+import logging
+
+import pytest
+
+log = logging.getLogger(__name__)
+
+
+@pytest.mark.slow_test
+def test_issue_64111(salt_master, salt_minion, salt_call_cli):
+    # Apply a state whose included SLS imports a Jinja macro "with context".
+    # This must be an integration test; a functional test misses the issue.
+
+    macros_jinja = """
+    {% macro a_jinja_macro(arg) -%}
+    {{ arg }}
+    {%- endmacro %}
+    """
+
+    init_sls = """
+    include:
+      - common.file1
+    """
+
+    file1_sls = """
+    {% from 'common/macros.jinja' import a_jinja_macro with context %}
+
+    a state id:
+      cmd.run:
+        - name: echo {{ a_jinja_macro("hello world") }}
+    """
+    tf = salt_master.state_tree.base.temp_file  # used below as tf(path, contents)
+
+    with tf("common/macros.jinja", macros_jinja):
+        with tf("common/init.sls", init_sls):
+            with tf("common/file1.sls", file1_sls):
+                ret = salt_call_cli.run("state.apply", "common")
+                assert ret.returncode == 0
diff --git a/tests/pytests/unit/utils/jinja/test_salt_cache_loader.py b/tests/pytests/unit/utils/jinja/test_salt_cache_loader.py
index 38c5ce5b724..e0f5fa158ff 100644
--- a/tests/pytests/unit/utils/jinja/test_salt_cache_loader.py
+++ b/tests/pytests/unit/utils/jinja/test_salt_cache_loader.py
@@ -15,7 +15,7 @@ import salt.utils.json  # pylint: disable=unused-import
 import salt.utils.stringutils  # pylint: disable=unused-import
 import salt.utils.yaml  # pylint: disable=unused-import
 from salt.utils.jinja import SaltCacheLoader
-from tests.support.mock import Mock, patch
+from tests.support.mock import Mock, call, patch
 
 
 @pytest.fixture
@@ -224,14 +224,45 @@ def test_file_client_kwarg(minion_opts, mock_file_client):
     assert loader._file_client is mock_file_client
 
 
-def test_cache_loader_shutdown(minion_opts, mock_file_client):
+def test_cache_loader_passed_file_client(minion_opts, mock_file_client):
     """
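-    The shudown method can be called without raising an exception when the
-    file_client does not have a destroy method
+    The loader destroys a file client it created itself and leaves a file
+    client that was passed in untouched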
     """
-    assert not hasattr(mock_file_client, "destroy")
-    mock_file_client.opts = minion_opts
-    loader = SaltCacheLoader(minion_opts, _file_client=mock_file_client)
-    assert loader._file_client is mock_file_client
-    # Shutdown method should not raise any exceptions
-    loader.shutdown()
+    # Test SaltCacheLoader creating and destroying the file client created
+    file_client = Mock()
+    with patch("salt.fileclient.get_file_client", return_value=file_client):
+        loader = SaltCacheLoader(minion_opts)
+        assert loader._file_client is None
+        with loader:
+            assert loader._file_client is file_client
+        assert loader._file_client is None
+        assert file_client.mock_calls == [call.destroy()]
+
+    # Test SaltCacheLoader reusing the file client passed
+    file_client = Mock()
+    file_client.opts = {"file_roots": minion_opts["file_roots"]}
+    with patch("salt.fileclient.get_file_client", return_value=Mock()):
+        loader = SaltCacheLoader(minion_opts, _file_client=file_client)
+        assert loader._file_client is file_client
+        with loader:
+            assert loader._file_client is file_client
+        assert loader._file_client is file_client
+        assert file_client.mock_calls == []
+
+    # Test SaltCacheLoader creating a client even though a file client was
+    # passed because the "file_roots" option is different, and, as such,
+    # the destroy method on the new file client is called, but not on the
+    # file client passed in.
+    file_client = Mock()
+    file_client.opts = {"file_roots": ""}
+    new_file_client = Mock()
+    with patch("salt.fileclient.get_file_client", return_value=new_file_client):
+        loader = SaltCacheLoader(minion_opts, _file_client=file_client)
+        assert loader._file_client is file_client
+        with loader:
+            assert loader._file_client is not file_client
+            assert loader._file_client is new_file_client
+        assert loader._file_client is None
+        assert file_client.mock_calls == []
+        assert new_file_client.mock_calls == [call.destroy()]
-- 
2.40.0