forked from pool/julia
189 lines
8.1 KiB
Diff
189 lines
8.1 KiB
Diff
|
From 959902f1c6099c1b513e29103b998545c16731fc Mon Sep 17 00:00:00 2001
|
||
|
From: Valentin Churavy <vchuravy@users.noreply.github.com>
|
||
|
Date: Thu, 27 Apr 2023 16:27:09 -0400
|
||
|
Subject: [PATCH] Support both Float16 ABIs depending on LLVM and platform
|
||
|
(#49527)
|
||
|
|
||
|
There are two Float16 ABIs in the wild, one for platforms that have a
|
||
|
dedicated Float16 register and the original one where we used i16.
|
||
|
|
||
|
LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC.
|
||
|
|
||
|
Co-authored-by: Gabriel Baraldi <baraldigabriel@gmail.com>
|
||
|
---
|
||
|
src/aotcompile.cpp | 11 +++++++--
|
||
|
src/codegen.cpp | 56 ++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
src/jitlayers.cpp | 2 ++
|
||
|
src/llvm-version.h | 10 +++++++++
|
||
|
4 files changed, 77 insertions(+), 2 deletions(-)
|
||
|
|
||
|
diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp
|
||
|
index 391c5d3df46fb..2a14e2a4fa0ab 100644
|
||
|
--- a/src/aotcompile.cpp
|
||
|
+++ b/src/aotcompile.cpp
|
||
|
@@ -494,6 +494,7 @@ static void reportWriterError(const ErrorInfoBase &E)
|
||
|
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
|
||
|
}
|
||
|
|
||
|
+#if JULIA_FLOAT16_ABI == 1
|
||
|
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
|
||
|
{
|
||
|
Function *target = M.getFunction(alias);
|
||
|
@@ -510,7 +511,8 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
|
||
|
auto val = builder.CreateCall(target, CallArgs);
|
||
|
builder.CreateRet(val);
|
||
|
}
|
||
|
-
|
||
|
+#endif
|
||
|
+void emitFloat16Wrappers(Module &M, bool external);
|
||
|
|
||
|
// takes the running content that has collected in the shadow module and dump it to disk
|
||
|
// this builds the object file portion of the sysimage files for fast startup
|
||
|
@@ -1003,6 +1006,7 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
|
||
|
}
|
||
|
|
||
|
if (inject_crt) {
|
||
|
+#if JULIA_FLOAT16_ABI == 1
|
||
|
// We would like to emit an alias or an weakref alias to redirect these symbols
|
||
|
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
|
||
|
// So for now we inject a definition of these functions that calls our runtime
|
||
|
@@ -1018,6 +1023,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
|
||
|
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
|
||
|
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
|
||
|
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));
|
||
|
+#else
|
||
|
+ emitFloat16Wrappers(M, false);
|
||
|
+#endif
|
||
|
|
||
|
#if defined(_OS_WINDOWS_)
|
||
|
// Windows expect that the function `_DllMainStartup` is present in an dll.
|
||
|
diff --git a/src/codegen.cpp b/src/codegen.cpp
|
||
|
index 329c4b452a9dc..f4b0fd518cd39 100644
|
||
|
--- a/src/codegen.cpp
|
||
|
+++ b/src/codegen.cpp
|
||
|
@@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate(
|
||
|
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
|
||
|
}
|
||
|
|
||
|
+#include <iostream>
|
||
|
static Function* gen_cfun_wrapper(
|
||
|
Module *into, jl_codegen_params_t ¶ms,
|
||
|
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
|
||
|
@@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
+// Handle FLOAT16 ABI v2
|
||
|
+#if JULIA_FLOAT16_ABI == 2
|
||
|
+static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external) // Define `wrapperName` (FTwrapper) in M as a trampoline that bitcasts its args/result and forwards to `calledName` (FTcalled).
|
||
|
+{
|
||
|
+ Function *calledFun = M.getFunction(calledName); // reuse calledName if M already declares/defines it
|
||
|
+ if (!calledFun) {
|
||
|
+ calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M); // otherwise add an external declaration, resolved at link/JIT time
|
||
|
+ }
|
||
|
+ auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage; // external=true exports the wrapper; false keeps it local to M
|
||
|
+ auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M); // NOTE(review): LLVM uniquifies the name (e.g. "name.1") if M already has wrapperName; confirm callers never hit that
|
||
|
+ wrapperFun->addFnAttr(Attribute::AlwaysInline); // the wrapper is only a type-punning shim around the call
|
||
|
+ llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
|
||
|
+ SmallVector<Value *, 4> CallArgs;
|
||
|
+ if (wrapperFun->arg_size() != calledFun->arg_size()){ // args are paired 1:1, so the arities must agree
|
||
|
+ llvm::errs() << "FATAL ERROR: Can't match wrapper to called function: " << wrapperName << " -> " << calledName << "\n"; // name both symbols and terminate the line before aborting
|
||
|
+ abort();
|
||
|
+ }
|
||
|
+ for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin(); // walk both parameter lists in lockstep
|
||
|
+ wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg)
|
||
|
+ {
|
||
|
+ CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType())); // reinterpret each arg as the callee's param type (e.g. half -> i16)
|
||
|
+ }
|
||
|
+ auto val = builder.CreateCall(calledFun, CallArgs);
|
||
|
+ auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType()); // reinterpret the callee's result as the wrapper's return type (e.g. i16 -> half)
|
||
|
+ builder.CreateRet(retval);
|
||
|
+}
|
||
|
+
|
||
|
+void emitFloat16Wrappers(Module &M, bool external) // Define the half-typed conversion libcall symbols in M as wrappers over Julia's i16-based julia__* routines; `external` picks External vs Internal linkage.
|
||
|
+{
|
||
|
+ auto &ctx = M.getContext();
|
||
|
+ makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), // half -> float
|
||
|
+ FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
|
||
|
+ makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), // alternate symbol name, same half -> float routine
|
||
|
+ FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
|
||
|
+ makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), // float -> half
|
||
|
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
|
||
|
+ makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), // alternate symbol name, same float -> half routine
|
||
|
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
|
||
|
+ makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false), // double -> half
|
||
|
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false), external);
|
||
|
+}
|
||
|
+
|
||
|
+static void init_f16_funcs(void) // Build a "F16Wrappers" module holding the Float16 libcall wrappers and add it to the JIT (ABI v2 builds only).
|
||
|
+{
|
||
|
+ auto ctx = jl_ExecutionEngine->acquireContext();
|
||
|
+ auto TSM = jl_create_ts_module("F16Wrappers", ctx, imaging_default());
|
||
|
+ auto aliasM = TSM.getModuleUnlocked();
|
||
|
+ emitFloat16Wrappers(*aliasM, true); // external linkage; presumably so other JIT'd code can resolve these symbols (confirm)
|
||
|
+ jl_ExecutionEngine->addModule(std::move(TSM)); // hand ownership of the module to the JIT
|
||
|
+}
|
||
|
+#endif
|
||
|
+
|
||
|
static void init_jit_functions(void)
|
||
|
{
|
||
|
add_named_global(jlstack_chk_guard_var, &__stack_chk_guard);
|
||
|
@@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
|
||
|
jl_init_llvm();
|
||
|
// Now that the execution engine exists, initialize all modules
|
||
|
init_jit_functions();
|
||
|
+#if JULIA_FLOAT16_ABI == 2
|
||
|
+ init_f16_funcs();
|
||
|
+#endif
|
||
|
}
|
||
|
|
||
|
extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT
|
||
|
diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp
|
||
|
index 37302e8ca2ace..b3ec102821858 100644
|
||
|
--- a/src/jitlayers.cpp
|
||
|
+++ b/src/jitlayers.cpp
|
||
|
@@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()
|
||
|
|
||
|
JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
|
||
|
|
||
|
+#if JULIA_FLOAT16_ABI == 1
|
||
|
orc::SymbolAliasMap jl_crt = {
|
||
|
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
|
||
|
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
|
||
|
@@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
|
||
|
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
|
||
|
};
|
||
|
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
|
||
|
+#endif
|
||
|
|
||
|
#ifdef MSAN_EMUTLS_WORKAROUND
|
||
|
orc::SymbolMap msan_crt;
|
||
|
diff --git a/src/llvm-version.h b/src/llvm-version.h
|
||
|
index 4e15e787b7de8..a3f3774b6dc15 100644
|
||
|
--- a/src/llvm-version.h
|
||
|
+++ b/src/llvm-version.h
|
||
|
@@ -2,6 +2,7 @@
|
||
|
|
||
|
#include <llvm/Config/llvm-config.h>
|
||
|
#include "julia_assert.h"
|
||
|
+#include "platform.h"
|
||
|
|
||
|
// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
|
||
|
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
|
||
|
@@ -17,6 +18,15 @@
|
||
|
#define JL_LLVM_OPAQUE_POINTERS 1
|
||
|
#endif
|
||
|
|
||
|
+// Before GCC 12, libgcc's Float16<->Float32 conversion helpers passed/returned
|
||
|
+// the half value as an i16. GCC 12 silently changed that ABI to pass Float16
|
||
|
+// in floating-point (Float32) registers; LLVM 15 follows it, except on PPC.
|
||
|
+#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
|
||
|
+#define JULIA_FLOAT16_ABI 1 // ABI v1: half passed as i16 (LLVM < 15, or any PPC target)
|
||
|
+#else
|
||
|
+#define JULIA_FLOAT16_ABI 2 // ABI v2: half passed in FP registers (LLVM >= 15, non-PPC)
|
||
|
+#endif
|
||
|
+
|
||
|
#ifdef __cplusplus
|
||
|
#if defined(__GNUC__) && (__GNUC__ >= 9)
|
||
|
// Added in GCC 9, this warning is annoying
|