13 #include "llvm/ADT/STLExtras.h"
14 #include "llvm/Support/FormatVariadic.h"
24 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
25 for (
const auto &en :
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
26 Value attribution = en.value();
29 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
31 uint64_t numElements = type.getNumElements();
35 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
36 std::string name = std::string(
37 llvm::formatv(
"__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
38 auto globalOp = rewriter.
create<LLVM::GlobalOp>(
39 gpuFuncOp.getLoc(), arrayType,
false,
40 LLVM::Linkage::Internal, name,
Attribute(),
41 0, workgroupAddrSpace);
42 workgroupBuffers.push_back(globalOp);
50 convertedType.template cast<LLVM::LLVMPointerType>().getElementType();
54 gpuFuncOp.front().getNumArguments());
56 gpuFuncOp.getFunctionType(),
false, signatureConversion);
61 for (
const auto &attr : gpuFuncOp->getAttrs()) {
63 attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
64 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
66 attributes.push_back(attr);
71 if (gpuFuncOp.isKernel())
72 attributes.emplace_back(kernelAttributeName, rewriter.
getUnitAttr());
73 auto llvmFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
74 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
75 LLVM::Linkage::External,
false, LLVM::CConv::C,
88 unsigned numProperArguments = gpuFuncOp.getNumArguments();
91 LLVM::GlobalOp global = en.value();
92 Value address = rewriter.
create<LLVM::AddressOfOp>(loc, global);
96 loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
103 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
104 auto type = attribution.
getType().
cast<MemRefType>();
107 signatureConversion.
remapInput(numProperArguments + en.index(), descr);
111 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
112 auto int64Ty = IntegerType::get(rewriter.
getContext(), 64);
113 for (
const auto &en :
llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
114 Value attribution = en.value();
115 auto type = attribution.
getType().
cast<MemRefType>();
116 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
121 auto ptrType = LLVM::LLVMPointerType::get(
123 .template cast<Type>(),
125 Value numElements = rewriter.
create<LLVM::ConstantOp>(
126 gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
128 gpuFuncOp.getLoc(), ptrType, numElements, 0);
132 numProperArguments + numWorkgroupAttributions + en.index(), descr);
140 &signatureConversion)))
150 auto memrefTy = en.value().dyn_cast<MemRefType>();
153 assert(memrefTy.hasStaticShape() &&
154 "Bare pointer convertion used with dynamically-shaped memrefs");
158 assert(remapping && remapping->size == 1 &&
159 "Type converter should produce 1-to-1 mapping for bare memrefs");
161 llvmFuncOp.getBody().getArgument(remapping->inputNo);
162 auto placeholder = rewriter.
create<LLVM::UndefOp>(
176 const char formatStringPrefix[] =
"printfFormat_";
178 unsigned stringNumber = 0;
181 stringConstName.clear();
182 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
183 }
while (moduleOp.lookupSymbol(stringConstName));
184 return stringConstName;
187 template <
typename T>
191 LLVM::LLVMFunctionType type) {
192 LLVM::LLVMFuncOp ret;
193 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
194 ConversionPatternRewriter::InsertionGuard guard(rewriter);
196 ret = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, type,
197 LLVM::Linkage::External);
203 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
205 Location loc = gpuPrintfOp->getLoc();
208 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
214 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
218 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
219 LLVM::LLVMFuncOp ocklAppendArgs;
220 if (!adaptor.getArgs().empty()) {
222 moduleOp, loc, rewriter,
"__ockl_printf_append_args",
223 LLVM::LLVMFunctionType::get(
224 llvmI64, {llvmI64, llvmI32, llvmI64, llvmI64, llvmI64,
225 llvmI64, llvmI64, llvmI64, llvmI64, llvmI32}));
228 moduleOp, loc, rewriter,
"__ockl_printf_append_string_n",
229 LLVM::LLVMFunctionType::get(
231 {llvmI64, i8Ptr, llvmI64, llvmI32}));
234 Value zeroI64 = rewriter.
create<LLVM::ConstantOp>(loc, llvmI64, 0);
235 auto printfBeginCall = rewriter.
create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
236 Value printfDesc = printfBeginCall.getResult();
242 formatString.push_back(
'\0');
243 size_t formatStringSize = formatString.size_in_bytes();
245 auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
246 LLVM::GlobalOp global;
250 global = rewriter.
create<LLVM::GlobalOp>(
252 true, LLVM::Linkage::Internal, stringConstName,
257 Value globalPtr = rewriter.
create<LLVM::AddressOfOp>(loc, global);
261 rewriter.
create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
263 Value oneI32 = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, 1);
264 Value zeroI32 = rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, 0);
266 auto appendFormatCall = rewriter.
create<LLVM::CallOp>(
267 loc, ocklAppendStringN,
268 ValueRange{printfDesc, stringStart, stringLen,
269 adaptor.getArgs().empty() ? oneI32 : zeroI32});
270 printfDesc = appendFormatCall.
getResult();
273 constexpr
size_t argsPerAppend = 7;
274 size_t nArgs = adaptor.getArgs().size();
275 for (
size_t group = 0; group < nArgs; group += argsPerAppend) {
276 size_t bound =
std::min(group + argsPerAppend, nArgs);
277 size_t numArgsThisCall = bound - group;
280 arguments.push_back(printfDesc);
282 rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall));
283 for (
size_t i = group; i < bound; ++i) {
284 Value arg = adaptor.getArgs()[i];
286 if (!floatType.isF64())
287 arg = rewriter.
create<LLVM::FPExtOp>(
289 arg = rewriter.
create<LLVM::BitcastOp>(loc, llvmI64, arg);
292 arg = rewriter.
create<LLVM::ZExtOp>(loc, llvmI64, arg);
294 arguments.push_back(arg);
297 for (
size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
298 arguments.push_back(zeroI64);
301 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
302 arguments.push_back(isLast);
303 auto call = rewriter.
create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
311 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
313 Location loc = gpuPrintfOp->getLoc();
316 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
321 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
323 auto printfType = LLVM::LLVMFunctionType::get(rewriter.
getI32Type(), {i8Ptr},
325 LLVM::LLVMFuncOp printfDecl =
332 formatString.push_back(
'\0');
334 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
335 LLVM::GlobalOp global;
339 global = rewriter.
create<LLVM::GlobalOp>(
341 true, LLVM::Linkage::Internal, stringConstName,
346 Value globalPtr = rewriter.
create<LLVM::AddressOfOp>(loc, global);
351 auto argsRange = adaptor.getArgs();
353 printfArgs.reserve(argsRange.size() + 1);
354 printfArgs.push_back(stringStart);
355 printfArgs.append(argsRange.begin(), argsRange.end());
357 rewriter.
create<LLVM::CallOp>(loc, printfDecl, printfArgs);
363 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
365 Location loc = gpuPrintfOp->getLoc();
368 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
373 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
376 LLVM::LLVMFunctionType::get(rewriter.
getI32Type(), {i8Ptr, i8Ptr});
377 LLVM::LLVMFuncOp vprintfDecl =
384 formatString.push_back(
'\0');
386 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
387 LLVM::GlobalOp global;
391 global = rewriter.
create<LLVM::GlobalOp>(
393 true, LLVM::Linkage::Internal, stringConstName,
398 Value globalPtr = rewriter.
create<LLVM::AddressOfOp>(loc, global);
404 for (
Value arg : adaptor.getArgs()) {
405 Type type = arg.getType();
406 Value promotedArg = arg;
410 promotedArg = rewriter.
create<LLVM::FPExtOp>(loc, type, arg);
412 types.push_back(type);
413 args.push_back(promotedArg);
417 Type structPtrType = LLVM::LLVMPointerType::get(structType);
420 Value tempAlloc = rewriter.
create<LLVM::AllocaOp>(loc, structPtrType, one,
424 loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc,
426 rewriter.
create<LLVM::StoreOp>(loc, arg, ptr);
428 tempAlloc = rewriter.
create<LLVM::BitcastOp>(loc, i8Ptr, tempAlloc);
429 std::array<Value, 2> printfArgs = {stringStart, tempAlloc};
431 rewriter.
create<LLVM::CallOp>(loc, vprintfDecl, printfArgs);
441 if (llvm::none_of(operandTypes,
442 [](
Type type) {
return type.
isa<VectorType>(); })) {
454 Value result = rewriter.
create<LLVM::UndefOp>(loc, vectorType);
457 Type elementType = vectorType.getElementType();
459 for (int64_t i = 0; i < vectorType.getNumElements(); ++i) {
460 Value index = rewriter.
create<LLVM::ConstantOp>(loc, indexType, i);
461 auto extractElement = [&](
Value operand) ->
Value {
462 if (!operand.getType().isa<VectorType>())
464 return rewriter.
create<LLVM::ExtractElementOp>(loc, operand, index);
466 auto scalarOperands =
467 llvm::to_vector(llvm::map_range(operands, extractElement));
469 rewriter.
create(loc, name, scalarOperands, elementType, op->
getAttrs());
470 rewriter.
create<LLVM::InsertElementOp>(loc, result, scalarOp->
getResult(0),
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
static SmallString< 16 > getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
IntegerAttr getIndexAttr(int64_t value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void replaceUsesOfBlockArgument(BlockArgument from, Value to)
Replace all the uses of the block argument from with value to.
TypeConverter * typeConverter
An optional type converter for use by this pattern.
LLVMTypeConverter * getTypeConverter() const
Conversion from types to the LLVM IR dialect.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result)
Convert a function type.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef< Type > types, bool isPacked=false)
Gets or creates a literal struct with the given body in the provided context.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
unsigned getNumSuccessors()
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
unsigned getNumResults()
Return the number of results held by this operation.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides all of the information necessary to convert a type signature.
std::optional< InputMapping > getInputMapping(unsigned input) const
Get the input mapping for the given argument.
void remapInput(unsigned origInputNo, Value replacement)
Remap an input of the original signature to another replacement value.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &converter)
Unrolls op if it's operating on vectors.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.