MLIR 22.0.0git
FunctionCallUtils.cpp
Go to the documentation of this file.
1//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. to support printing and debugging, among other uses).
11//
12//===----------------------------------------------------------------------===//
13
16#include "mlir/IR/Builders.h"
18#include "mlir/Support/LLVM.h"
19
20using namespace mlir;
21using namespace mlir::LLVM;
22
23/// Helper functions to lookup or create the declaration for commonly used
24/// external C function calls. The list of functions provided here must be
25/// implemented separately (e.g. as part of a support runtime library or as
26/// part of the libc).
27static constexpr llvm::StringRef kPrintI64 = "printI64";
28static constexpr llvm::StringRef kPrintU64 = "printU64";
29static constexpr llvm::StringRef kPrintF16 = "printF16";
30static constexpr llvm::StringRef kPrintBF16 = "printBF16";
31static constexpr llvm::StringRef kPrintF32 = "printF32";
32static constexpr llvm::StringRef kPrintF64 = "printF64";
33static constexpr llvm::StringRef kPrintApFloat = "printApFloat";
34static constexpr llvm::StringRef kPrintString = "printString";
35static constexpr llvm::StringRef kPrintOpen = "printOpen";
36static constexpr llvm::StringRef kPrintClose = "printClose";
37static constexpr llvm::StringRef kPrintComma = "printComma";
38static constexpr llvm::StringRef kPrintNewline = "printNewline";
39static constexpr llvm::StringRef kMalloc = "malloc";
40static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
41static constexpr llvm::StringRef kFree = "free";
42static constexpr llvm::StringRef kGenericAlloc = "_mlir_memref_to_llvm_alloc";
43static constexpr llvm::StringRef kGenericAlignedAlloc =
44 "_mlir_memref_to_llvm_aligned_alloc";
45static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
46static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
47
48namespace {
49/// Search for an LLVMFuncOp with a given name within an operation with the
50/// SymbolTable trait. An optional collection of cached symbol tables can be
51/// given to avoid a linear scan of the symbol table operation.
52LLVM::LLVMFuncOp lookupFuncOp(StringRef name, Operation *symbolTableOp,
53 SymbolTableCollection *symbolTables = nullptr) {
54 if (symbolTables) {
55 return symbolTables->lookupSymbolIn<LLVM::LLVMFuncOp>(
56 symbolTableOp, StringAttr::get(symbolTableOp->getContext(), name));
57 }
58
59 return llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
60 SymbolTable::lookupSymbolIn(symbolTableOp, name));
61}
62} // namespace
63
64/// Generic print function lookupOrCreate helper.
65FailureOr<LLVM::LLVMFuncOp>
67 ArrayRef<Type> paramTypes, Type resultType,
68 bool isVarArg, bool isReserved,
69 SymbolTableCollection *symbolTables) {
70 assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
71 "expected SymbolTable operation");
72 auto func = lookupFuncOp(name, moduleOp, symbolTables);
73 auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
74 // Assert the signature of the found function is same as expected
75 if (func) {
76 if (funcT != func.getFunctionType()) {
77 if (isReserved) {
78 func.emitError("redefinition of reserved function '")
79 << name << "' of different type " << func.getFunctionType()
80 << " is prohibited";
81 } else {
82 func.emitError("redefinition of function '")
83 << name << "' of different type " << funcT << " is prohibited";
84 }
85 return failure();
86 }
87 return func;
88 }
89
91 assert(!moduleOp->getRegion(0).empty() && "expected non-empty region");
92 b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
93 auto funcOp = LLVM::LLVMFuncOp::create(
94 b, moduleOp->getLoc(), name,
95 LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
96
97 if (symbolTables) {
98 SymbolTable &symbolTable = symbolTables->getSymbolTable(moduleOp);
99 symbolTable.insert(funcOp, moduleOp->getRegion(0).front().begin());
100 }
101
102 return funcOp;
103}
104
105static FailureOr<LLVM::LLVMFuncOp>
106lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name,
107 ArrayRef<Type> paramTypes, Type resultType,
108 SymbolTableCollection *symbolTables) {
109 return lookupOrCreateFn(b, moduleOp, name, paramTypes, resultType,
110 /*isVarArg=*/false, /*isReserved=*/true,
111 symbolTables);
112}
113
114FailureOr<LLVM::LLVMFuncOp>
116 SymbolTableCollection *symbolTables) {
118 b, moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
119 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
120}
121
122FailureOr<LLVM::LLVMFuncOp>
124 SymbolTableCollection *symbolTables) {
126 b, moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
127 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
128}
129
130FailureOr<LLVM::LLVMFuncOp>
132 SymbolTableCollection *symbolTables) {
134 b, moduleOp, kPrintF16,
135 IntegerType::get(moduleOp->getContext(), 16), // bits!
136 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
137}
138
139FailureOr<LLVM::LLVMFuncOp>
141 SymbolTableCollection *symbolTables) {
143 b, moduleOp, kPrintBF16,
144 IntegerType::get(moduleOp->getContext(), 16), // bits!
145 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
146}
147
148FailureOr<LLVM::LLVMFuncOp>
150 SymbolTableCollection *symbolTables) {
152 b, moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
153 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
154}
155
156FailureOr<LLVM::LLVMFuncOp>
158 SymbolTableCollection *symbolTables) {
160 b, moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
161 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
162}
163
164FailureOr<LLVM::LLVMFuncOp>
166 SymbolTableCollection *symbolTables) {
168 b, moduleOp, kPrintApFloat,
169 {IntegerType::get(moduleOp->getContext(), 32),
170 IntegerType::get(moduleOp->getContext(), 64)},
171 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
172}
173
174static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
175 return LLVM::LLVMPointerType::get(context);
176}
177
178static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
179 // A char pointer and void ptr are the same in LLVM IR.
180 return getCharPtr(context);
181}
182
183FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
184 OpBuilder &b, Operation *moduleOp,
185 std::optional<StringRef> runtimeFunctionName,
186 SymbolTableCollection *symbolTables) {
188 b, moduleOp, runtimeFunctionName.value_or(kPrintString),
189 getCharPtr(moduleOp->getContext()),
190 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
191}
192
193FailureOr<LLVM::LLVMFuncOp>
195 SymbolTableCollection *symbolTables) {
197 b, moduleOp, kPrintOpen, {},
198 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
199}
200
201FailureOr<LLVM::LLVMFuncOp>
203 SymbolTableCollection *symbolTables) {
205 b, moduleOp, kPrintClose, {},
206 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
207}
208
209FailureOr<LLVM::LLVMFuncOp>
211 SymbolTableCollection *symbolTables) {
213 b, moduleOp, kPrintComma, {},
214 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
215}
216
217FailureOr<LLVM::LLVMFuncOp>
219 SymbolTableCollection *symbolTables) {
221 b, moduleOp, kPrintNewline, {},
222 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
223}
224
225FailureOr<LLVM::LLVMFuncOp>
227 Type indexType,
228 SymbolTableCollection *symbolTables) {
229 return lookupOrCreateReservedFn(b, moduleOp, kMalloc, indexType,
230 getVoidPtr(moduleOp->getContext()),
231 symbolTables);
232}
233
234FailureOr<LLVM::LLVMFuncOp>
236 Type indexType,
237 SymbolTableCollection *symbolTables) {
239 b, moduleOp, kAlignedAlloc, {indexType, indexType},
240 getVoidPtr(moduleOp->getContext()), symbolTables);
241}
242
243FailureOr<LLVM::LLVMFuncOp>
245 SymbolTableCollection *symbolTables) {
247 b, moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
248 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
249}
250
251FailureOr<LLVM::LLVMFuncOp>
253 Type indexType,
254 SymbolTableCollection *symbolTables) {
255 return lookupOrCreateReservedFn(b, moduleOp, kGenericAlloc, indexType,
256 getVoidPtr(moduleOp->getContext()),
257 symbolTables);
258}
259
261 OpBuilder &b, Operation *moduleOp, Type indexType,
262 SymbolTableCollection *symbolTables) {
264 b, moduleOp, kGenericAlignedAlloc, {indexType, indexType},
265 getVoidPtr(moduleOp->getContext()), symbolTables);
266}
267
268FailureOr<LLVM::LLVMFuncOp>
270 SymbolTableCollection *symbolTables) {
272 b, moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
273 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
274}
275
276FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreateMemRefCopyFn(
277 OpBuilder &b, Operation *moduleOp, Type indexType,
278 Type unrankedDescriptorType, SymbolTableCollection *symbolTables) {
280 b, moduleOp, kMemRefCopy,
281 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
282 LLVM::LLVMVoidType::get(moduleOp->getContext()), symbolTables);
283}
static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintI64
Helper functions to lookup or create the declaration for commonly used external C function calls.
static constexpr llvm::StringRef kPrintApFloat
static FailureOr< LLVM::LLVMFuncOp > lookupOrCreateReservedFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes, Type resultType, SymbolTableCollection *symbolTables)
static constexpr llvm::StringRef kFree
static constexpr llvm::StringRef kPrintU64
static constexpr llvm::StringRef kPrintBF16
static constexpr llvm::StringRef kPrintString
static constexpr llvm::StringRef kGenericAlignedAlloc
static constexpr llvm::StringRef kAlignedAlloc
static constexpr llvm::StringRef kPrintClose
static constexpr llvm::StringRef kMalloc
static constexpr llvm::StringRef kMemRefCopy
static constexpr llvm::StringRef kPrintComma
static constexpr llvm::StringRef kGenericAlloc
static constexpr llvm::StringRef kPrintNewline
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context)
static constexpr llvm::StringRef kPrintOpen
static constexpr llvm::StringRef kGenericFree
static constexpr llvm::StringRef kPrintF16
static constexpr llvm::StringRef kPrintF32
static constexpr llvm::StringRef kPrintF64
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
iterator begin()
Definition Block.h:143
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
This class represents a collection of SymbolTables.
virtual SymbolTable & getSymbolTable(Operation *op)
Lookup, or create, a symbol table for an operation.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintBF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintOpenFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMemRefCopyFn(OpBuilder &b, Operation *moduleOp, Type indexType, Type unrankedDescriptorType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCommaFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintI64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
Helper functions to look up or create the declaration for commonly used external C function calls.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintNewlineFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintCloseFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintU64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateMallocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateAlignedAllocFn(OpBuilder &b, Operation *moduleOp, Type indexType, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintStringFn(OpBuilder &b, Operation *moduleOp, std::optional< StringRef > runtimeFunctionName={}, SymbolTableCollection *symbolTables=nullptr)
Declares a function to print a C-string.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF16Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateGenericFreeFn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp, SymbolTableCollection *symbolTables=nullptr)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.