From 959902f1c6099c1b513e29103b998545c16731fc Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 27 Apr 2023 16:27:09 -0400 Subject: [PATCH] Support both Float16 ABIs depending on LLVM and platform (#49527) There are two Float16 ABIs in the wild, one for platforms that have a defined register and the original one where we used i16. LLVM 15 follows GCC and uses the new ABI on x86/ARM but not PPC. Co-authored-by: Gabriel Baraldi --- src/aotcompile.cpp | 11 +++++++-- src/codegen.cpp | 56 ++++++++++++++++++++++++++++++++++++++++++++++ src/jitlayers.cpp | 2 ++ src/llvm-version.h | 10 +++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index 391c5d3df46fb..2a14e2a4fa0ab 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -494,6 +494,7 @@ static void reportWriterError(const ErrorInfoBase &E) jl_safe_printf("ERROR: failed to emit output file %s\n", err.c_str()); } +#if JULIA_FLOAT16_ABI == 1 static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionType *FT) { Function *target = M.getFunction(alias); @@ -510,7 +511,8 @@ static void injectCRTAlias(Module &M, StringRef name, StringRef alias, FunctionT auto val = builder.CreateCall(target, CallArgs); builder.CreateRet(val); } - +#endif +void emitFloat16Wrappers(Module &M, bool external); // takes the running content that has collected in the shadow module and dump it to disk // this builds the object file portion of the sysimage files for fast startup @@ -1003,6 +1006,7 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out } if (inject_crt) { +#if JULIA_FLOAT16_ABI == 1 // We would like to emit an alias or an weakref alias to redirect these symbols // but LLVM doesn't let us emit a GlobalAlias to a declaration... 
// So for now we inject a definition of these functions that calls our runtime @@ -1018,6 +1023,9 @@ static void add_output_impl(Module &M, TargetMachine &SourceTM, std::string *out FunctionType::get(Type::getHalfTy(Context), { Type::getFloatTy(Context) }, false)); injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(Context), { Type::getDoubleTy(Context) }, false)); +#else + emitFloat16Wrappers(M, false); +#endif #if defined(_OS_WINDOWS_) // Windows expect that the function `_DllMainStartup` is present in an dll. diff --git a/src/codegen.cpp b/src/codegen.cpp index 329c4b452a9dc..f4b0fd518cd39 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -5818,6 +5818,7 @@ static void emit_cfunc_invalidate( prepare_call_in(gf_thunk->getParent(), jlapplygeneric_func)); } +#include static Function* gen_cfun_wrapper( Module *into, jl_codegen_params_t &params, const function_sig_t &sig, jl_value_t *ff, const char *aliasname, @@ -8704,6 +8705,58 @@ static JuliaVariable *julia_const_gv(jl_value_t *val) return nullptr; } +// Handle FLOAT16 ABI v2 +#if JULIA_FLOAT16_ABI == 2 +static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external) +{ + Function *calledFun = M.getFunction(calledName); + if (!calledFun) { + calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M); + } + auto linkage = external ? 
Function::ExternalLinkage : Function::InternalLinkage; + auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M); + wrapperFun->addFnAttr(Attribute::AlwaysInline); + llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun)); + SmallVector CallArgs; + if (wrapperFun->arg_size() != calledFun->arg_size()){ + llvm::errs() << "FATAL ERROR: Can't match wrapper to called function"; + abort(); + } + for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin(); + wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg) + { + CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType())); + } + auto val = builder.CreateCall(calledFun, CallArgs); + auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType()); + builder.CreateRet(retval); +} + +void emitFloat16Wrappers(Module &M, bool external) +{ + auto &ctx = M.getContext(); + makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), + FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external); + makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), + FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external); + makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), + FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external); + makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), + FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external); + makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false), + 
FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false), external); +} + +static void init_f16_funcs(void) +{ + auto ctx = jl_ExecutionEngine->acquireContext(); + auto TSM = jl_create_ts_module("F16Wrappers", ctx, imaging_default()); + auto aliasM = TSM.getModuleUnlocked(); + emitFloat16Wrappers(*aliasM, true); + jl_ExecutionEngine->addModule(std::move(TSM)); +} +#endif + static void init_jit_functions(void) { add_named_global(jlstack_chk_guard_var, &__stack_chk_guard); @@ -8942,6 +8995,9 @@ extern "C" JL_DLLEXPORT void jl_init_codegen_impl(void) jl_init_llvm(); // Now that the execution engine exists, initialize all modules init_jit_functions(); +#if JULIA_FLOAT16_ABI == 2 + init_f16_funcs(); +#endif } extern "C" JL_DLLEXPORT void jl_teardown_codegen_impl() JL_NOTSAFEPOINT diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index 37302e8ca2ace..b3ec102821858 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -1383,6 +1383,7 @@ JuliaOJIT::JuliaOJIT() JD.addToLinkOrder(GlobalJD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); +#if JULIA_FLOAT16_ABI == 1 orc::SymbolAliasMap jl_crt = { { mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, @@ -1391,6 +1392,7 @@ JuliaOJIT::JuliaOJIT() { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } } }; cantFail(GlobalJD.define(orc::symbolAliases(jl_crt))); +#endif #ifdef MSAN_EMUTLS_WORKAROUND orc::SymbolMap msan_crt; diff --git a/src/llvm-version.h b/src/llvm-version.h index 4e15e787b7de8..a3f3774b6dc15 100644 --- a/src/llvm-version.h +++ b/src/llvm-version.h @@ -2,6 +2,7 @@ #include #include "julia_assert.h" +#include "platform.h" // The LLVM version used, JL_LLVM_VERSION, is represented as a 5-digit integer // of the form ABBCC, where A is the major version, B is minor, and C is patch. 
@@ -17,6 +18,15 @@ #define JL_LLVM_OPAQUE_POINTERS 1 #endif +// Pre GCC 12 libgcc defined the ABI for Float16->Float32 +// to take an i16. GCC 12 silently changed the ABI to now pass +// Float16 in Float32 registers. +#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_) +#define JULIA_FLOAT16_ABI 1 +#else +#define JULIA_FLOAT16_ABI 2 +#endif + #ifdef __cplusplus #if defined(__GNUC__) && (__GNUC__ >= 9) // Added in GCC 9, this warning is annoying