SHA256
1
0
forked from pool/julia
julia/support-float16-depending-on-llvm-and-platform.patch

189 lines
8.1 KiB
Diff
Raw Normal View History

From 959902f1c6099c1b513e29103b998545c16731fc Mon Sep 17 00:00:00 2001
From: Valentin Churavy <vchuravy@users.noreply.github.com>
Date: Thu, 27 Apr 2023 16:27:09 -0400
Subject: [PATCH] Support both Float16 ABIs depending on LLVM and platform
(#49527)
There are two Float16 ABIs in the wild, one for platforms that have a
defined register and the original one where we used i16.
LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC.
Co-authored-by: Gabriel Baraldi <baraldigabriel@gmail.com>
---
src/aotcompile.cpp | 11 +++++++--
src/codegen.cpp | 56 ++++++++++++++++++++++++++++++++++++++++++++++
src/jitlayers.cpp | 2 ++
src/llvm-version.h | 10 +++++++++
4 files changed, 77 insertions(+), 2 deletions(-)
diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp
index 391c5d3df46fb..2a14e2a4fa0ab 100644
--- a/src/aotcompile.cpp
+++ b/src/aotcompile.cpp
@@ -494,6 +494,7 @@ static void reportWriterError(const ErrorInfoBase &E)
jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str());
}
+#if JULIA_FLOAT16_ABI == 1
static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT)
{
Function *target = M.getFunction(alias);
@@ -510,7 +511,8 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT
auto val = builder.CreateCall(target, CallArgs);
builder.CreateRet(val);
}
-
+#endif
+void emitFloat16Wrappers(Module &M, bool external);
// takes the running content that has collected in the shadow module and dump it to disk
// this builds the object file portion of the sysimage files for fast startup
@@ -1003,6 +1006,7 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
}
if (inject_crt) {
+#if JULIA_FLOAT16_ABI == 1
// We would like to emit an alias or an weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
@@ -1018,6 +1023,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out
FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false));
+#else
+ emitFloat16Wrappers(M, false);
+#endif
#if defined(_OS_WINDOWS_)
// Windows expect that the function `_DllMainStartup` is present in an dll.
diff --git a/src/codegen.cpp b/src/codegen.cpp
index 329c4b452a9dc..f4b0fd518cd39 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate(
prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func));
}
+#include <iostream>
static Function* gen_cfun_wrapper(
Module *into, jl_codegen_params_t &params,
const function_sig_t &sig, jl_value_t *ff, const char *aliasname,
@@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
return nullptr;
}
+// Handle FLOAT16 ABI v2
+#if JULIA_FLOAT16_ABI == 2
+// Emit into `M` a function named `wrapperName` of type `FTwrapper` that forwards
+// to the function named `calledName`, bitcasting every argument to the callee's
+// parameter type and bitcasting the callee's result to the wrapper's return type.
+//
+// - `calledName` is looked up in `M`; when it is not present it is created with
+//   type `FTcalled` and external linkage.
+// - `external` selects ExternalLinkage (true) or InternalLinkage (false) for the
+//   wrapper.
+//
+// The wrapper is marked AlwaysInline. The process is aborted if the wrapper and
+// the callee do not take the same number of arguments.
+static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
+{
+ // Reuse the callee if the module already has a function with that name.
+ Function *calledFun = M.getFunction(calledName);
+ if (!calledFun) {
+ calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);
+ }
+ auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
+ auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M);
+ wrapperFun->addFnAttr(Attribute::AlwaysInline);
+ // The wrapper body is a single basic block named "top".
+ llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
+ SmallVector<Value *, 4> CallArgs;
+ // Arguments are forwarded one-to-one below, so the counts have to agree.
+ if (wrapperFun->arg_size() != calledFun->arg_size()){
+ llvm::errs() << "FATAL ERROR: Can't match wrapper to called function";
+ abort();
+ }
+ // Walk both argument lists in lockstep, casting each wrapper argument to the
+ // type of the corresponding callee parameter.
+ for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin();
+ wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg)
+ {
+ CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType()));
+ }
+ auto val = builder.CreateCall(calledFun, CallArgs);
+ // Cast the callee's result to the type the wrapper is declared to return.
+ auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType());
+ builder.CreateRet(retval);
+}
+
+// Emit into `M` the five Float16 conversion symbols as wrappers whose signatures
+// use `half`, each forwarding to a `julia_`-prefixed function whose signature
+// uses `i16` in place of `half`:
+//   __gnu_h2f_ieee, __extendhfsf2 -> julia__gnu_h2f_ieee  (half -> float)
+//   __gnu_f2h_ieee, __truncsfhf2  -> julia__gnu_f2h_ieee  (float -> half)
+//   __truncdfhf2                  -> julia__truncdfhf2    (double -> half)
+// `external` is passed through to makeCastCall, where it selects external
+// (true) or internal (false) linkage for the wrappers.
+void emitFloat16Wrappers(Module &M, bool external)
+{
+ auto &ctx = M.getContext();
+ // For every call: first FunctionType is the wrapper's, second is the callee's.
+ makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
+ FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
+ makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
+ FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
+ makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
+ makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
+ makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false),
+ FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false), external);
+}
+
+// Build a module named "F16Wrappers" that contains the Float16 conversion
+// wrappers with external linkage and add it to jl_ExecutionEngine.
+static void init_f16_funcs(void)
+{
+ auto ctx = jl_ExecutionEngine->acquireContext();
+ auto TSM = jl_create_ts_module("F16Wrappers", ctx, imaging_default());
+ auto aliasM = TSM.getModuleUnlocked();
+ // `true`: the wrappers get external linkage (see emitFloat16Wrappers).
+ emitFloat16Wrappers(*aliasM, true);
+ // Ownership of the module is moved into the execution engine.
+ jl_ExecutionEngine->addModule(std::move(TSM));
+}
+#endif
+
static void init_jit_functions(void)
{
add_named_global(jlstack_chk_guard_var, &__stack_chk_guard);
@@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void)
jl_init_llvm();
// Now that the execution engine exists, initialize all modules
init_jit_functions();
+#if JULIA_FLOAT16_ABI == 2
+ init_f16_funcs();
+#endif
}
extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT
diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp
index 37302e8ca2ace..b3ec102821858 100644
--- a/src/jitlayers.cpp
+++ b/src/jitlayers.cpp
@@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT()
JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);
+#if JULIA_FLOAT16_ABI == 1
orc::SymbolAliasMap jl_crt = {
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
@@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT()
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }
};
cantFail(GlobalJD.define(orc::symbolAliases(jl_crt)));
+#endif
#ifdef MSAN_EMUTLS_WORKAROUND
orc::SymbolMap msan_crt;
diff --git a/src/llvm-version.h b/src/llvm-version.h
index 4e15e787b7de8..a3f3774b6dc15 100644
--- a/src/llvm-version.h
+++ b/src/llvm-version.h
@@ -2,6 +2,7 @@
#include <llvm/Config/llvm-config.h>
#include "julia_assert.h"
+#include "platform.h"
// The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer
// of the form ABBCC, where A is the major version, B is minor, and C is patch.
@@ -17,6 +18,15 @@
#define JL_LLVM_OPAQUE_POINTERS 1
#endif
+// Before GCC 12, libgcc defined the ABI for Float16->Float32 conversion
+// to take an i16 (JULIA_FLOAT16_ABI == 1). GCC 12 silently changed the ABI
+// to pass Float16 in Float32 registers (JULIA_FLOAT16_ABI == 2).
+#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
+#define JULIA_FLOAT16_ABI 1
+#else
+#define JULIA_FLOAT16_ABI 2
+#endif
+
#ifdef __cplusplus
#if defined(__GNUC__) && (__GNUC__ >= 9)
// Added in GCC 9, this warning is annoying