forked from ROCm/hipblaslt
70 lines
2.0 KiB
Diff
70 lines
2.0 KiB
Diff
From c87ae445460a4f3d3a6a2ebfb4a21551ce79f6d3 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <Tom.Rix@amd.com>
|
|
Date: Sat, 19 Apr 2025 09:38:00 -0700
|
|
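Subject: [PATCH] hipblaslt: handle missing joblib

Make joblib an optional dependency of tensilelite/Tensile/Parallel.py.
When joblib cannot be imported, CPUThreadCount() returns 1, so
ParallelMap2() takes the serial path. In the same case
joblibParallelSupportsGenerator() returns True.

The serial path in ParallelMap2() now runs whenever threadCount <= 1,
regardless of globalParameters["ShowProgressBar"], and always iterates
through Utils.tqdm. It also replaces the old callFunc lambda. In that
lambda the conditional expression was parsed inside the lambda body, so
with multiArg=False it returned a new lambda instead of calling
function(args).

NOTE(review): distutils was removed from the standard library in
Python 3.12 (PEP 632). On such interpreters the
`from distutils.version import StrictVersion` line can fail whenever
joblib is installed, unless a distutils shim is present.
StrictVersion also rejects version strings with .dev/rc/post suffixes.
The bare `except:` clauses around `import joblib` also hide errors
other than ImportError.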
|
|
|
|
---
|
|
tensilelite/Tensile/Parallel.py | 33 ++++++++++++++++++++++++---------
|
|
1 file changed, 24 insertions(+), 9 deletions(-)
|
|
|
|
diff --git a/tensilelite/Tensile/Parallel.py b/tensilelite/Tensile/Parallel.py
|
|
index bf9a44db1c58..76812753f173 100644
|
|
--- a/tensilelite/Tensile/Parallel.py
|
|
+++ b/tensilelite/Tensile/Parallel.py
|
|
@@ -28,17 +28,28 @@ import sys
|
|
import time
|
|
import concurrent.futures
|
|
|
|
-from joblib import Parallel, delayed
|
|
+try:
|
|
+ import joblib
|
|
+except:
|
|
+ joblib = None
|
|
+
|
|
+if joblib != None:
|
|
+ from joblib import Parallel, delayed
|
|
|
|
def joblibParallelSupportsGenerator():
|
|
- import joblib
|
|
- from packaging.version import Version
|
|
+ try:
|
|
+ import joblib
|
|
+ except:
|
|
+ joblib = None
|
|
+ if joblib == None:
|
|
+ return True
|
|
+ from distutils.version import StrictVersion
|
|
joblibVer = joblib.__version__
|
|
- return Version(joblibVer) >= Version("1.4.0")
|
|
+ return StrictVersion(joblibVer) >= StrictVersion("1.4.0")
|
|
|
|
def CPUThreadCount(enable=True):
|
|
from .Common import globalParameters
|
|
- if not enable:
|
|
+ if not enable or joblib == None:
|
|
return 1
|
|
else:
|
|
if os.name == "nt":
|
|
@@ -190,10 +201,14 @@ def ParallelMap2(function, objects, message="", enable=True, multiArg=True, retu
|
|
from . import Utils
|
|
threadCount = CPUThreadCount(enable)
|
|
|
|
- if threadCount <= 1 and globalParameters["ShowProgressBar"]:
|
|
- # Provide a progress bar for single-threaded operation.
|
|
- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args)
|
|
- return [callFunc(args) for args in Utils.tqdm(objects, message)]
|
|
+ if threadCount <= 1 :
|
|
+ rv = []
|
|
+ for args in Utils.tqdm(objects, message):
|
|
+ if multiArg:
|
|
+ rv.append(function(*args))
|
|
+ else:
|
|
+ rv.append(function(args))
|
|
+ return rv
|
|
|
|
countMessage = ""
|
|
try:
|
|
--
|
|
2.48.1
|
|
|