12 #include "llvm/Support/FormatVariadic.h" 22 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23 for (
const auto &en :
llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
24 Value attribution = en.value();
27 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
29 uint64_t numElements = type.getNumElements();
34 std::string name = std::string(
35 llvm::formatv(
"__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
36 auto globalOp = rewriter.
create<LLVM::GlobalOp>(
37 gpuFuncOp.getLoc(), arrayType,
false,
38 LLVM::Linkage::Internal, name,
Attribute(),
39 0, gpu::GPUDialect::getWorkgroupAddressSpace());
40 workgroupBuffers.push_back(globalOp);
45 .
template cast<LLVM::LLVMPointerType>()
50 gpuFuncOp.front().getNumArguments());
52 gpuFuncOp.getFunctionType(),
false, signatureConversion);
57 for (
const auto &attr : gpuFuncOp->getAttrs()) {
60 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
62 attributes.push_back(attr);
67 if (gpuFuncOp.isKernel())
68 attributes.emplace_back(kernelAttributeName, rewriter.
getUnitAttr());
69 auto llvmFuncOp = rewriter.
create<LLVM::LLVMFuncOp>(
70 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71 LLVM::Linkage::External,
false, LLVM::CConv::C,
84 unsigned numProperArguments = gpuFuncOp.getNumArguments();
85 auto i32Type = IntegerType::get(rewriter.
getContext(), 32);
88 if (!workgroupBuffers.empty())
89 zero = rewriter.
create<LLVM::ConstantOp>(loc, i32Type,
92 LLVM::GlobalOp global = en.value();
93 Value address = rewriter.
create<LLVM::AddressOfOp>(loc, global);
104 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
105 auto type = attribution.
getType().
cast<MemRefType>();
108 signatureConversion.remapInput(numProperArguments + en.index(), descr);
112 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
113 auto int64Ty = IntegerType::get(rewriter.
getContext(), 64);
114 for (
const auto &en :
llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
115 Value attribution = en.value();
116 auto type = attribution.
getType().
cast<MemRefType>();
117 assert(type && type.hasStaticShape() &&
"unexpected type in attribution");
124 .
template cast<Type>(),
126 Value numElements = rewriter.
create<LLVM::ConstantOp>(
127 gpuFuncOp.getLoc(), int64Ty,
130 gpuFuncOp.getLoc(), ptrType, numElements, 0);
133 signatureConversion.remapInput(
134 numProperArguments + numWorkgroupAttributions + en.index(), descr);
142 &signatureConversion)))
151 template <
typename T>
156 LLVM::LLVMFuncOp ret;
157 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158 ConversionPatternRewriter::InsertionGuard guard(rewriter);
160 ret = rewriter.
create<LLVM::LLVMFuncOp>(loc, name, type,
161 LLVM::Linkage::External);
167 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
169 Location loc = gpuPrintfOp->getLoc();
179 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
184 LLVM::LLVMFuncOp ocklAppendArgs;
185 if (!adaptor.args().empty()) {
187 moduleOp, loc, rewriter,
"__ockl_printf_append_args",
189 llvmI64, {llvmI64, llvmI32, llvmI64, llvmI64, llvmI64,
190 llvmI64, llvmI64, llvmI64, llvmI64, llvmI32}));
193 moduleOp, loc, rewriter,
"__ockl_printf_append_string_n",
196 {llvmI64, i8Ptr, llvmI64, llvmI32}));
201 auto printfBeginCall = rewriter.
create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202 Value printfDesc = printfBeginCall.getResult(0);
205 unsigned stringNumber = 0;
208 stringConstName.clear();
210 }
while (moduleOp.lookupSymbol(stringConstName));
213 formatString.push_back(
'\0');
214 size_t formatStringSize = formatString.size_in_bytes();
217 LLVM::GlobalOp global;
221 global = rewriter.
create<LLVM::GlobalOp>(
223 true, LLVM::Linkage::Internal, stringConstName,
228 Value globalPtr = rewriter.
create<LLVM::AddressOfOp>(loc, global);
233 Value stringLen = rewriter.
create<LLVM::ConstantOp>(
241 auto appendFormatCall = rewriter.
create<LLVM::CallOp>(
242 loc, ocklAppendStringN,
243 ValueRange{printfDesc, stringStart, stringLen,
244 adaptor.args().empty() ? oneI32 : zeroI32});
245 printfDesc = appendFormatCall.
getResult(0);
248 constexpr
size_t argsPerAppend = 7;
249 size_t nArgs = adaptor.args().size();
250 for (
size_t group = 0; group < nArgs; group += argsPerAppend) {
251 size_t bound =
std::min(group + argsPerAppend, nArgs);
252 size_t numArgsThisCall = bound - group;
255 arguments.push_back(printfDesc);
256 arguments.push_back(rewriter.
create<LLVM::ConstantOp>(
258 for (
size_t i = group; i < bound; ++i) {
259 Value arg = adaptor.args()[i];
261 if (!floatType.isF64())
262 arg = rewriter.
create<LLVM::FPExtOp>(
264 arg = rewriter.
create<LLVM::BitcastOp>(loc, llvmI64, arg);
267 arg = rewriter.
create<LLVM::ZExtOp>(loc, llvmI64, arg);
269 arguments.push_back(arg);
272 for (
size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273 arguments.push_back(zeroI64);
276 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277 arguments.push_back(isLast);
278 auto call = rewriter.
create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
286 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
288 Location loc = gpuPrintfOp->getLoc();
297 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
301 LLVM::LLVMFuncOp printfDecl =
305 unsigned stringNumber = 0;
308 stringConstName.clear();
310 }
while (moduleOp.lookupSymbol(stringConstName));
313 formatString.push_back(
'\0');
316 LLVM::GlobalOp global;
320 global = rewriter.
create<LLVM::GlobalOp>(
322 true, LLVM::Linkage::Internal, stringConstName,
327 Value globalPtr = rewriter.
create<LLVM::AddressOfOp>(loc, global);
334 auto argsRange = adaptor.args();
336 printfArgs.reserve(argsRange.size() + 1);
337 printfArgs.push_back(stringStart);
338 printfArgs.append(argsRange.begin(), argsRange.end());
340 rewriter.
create<LLVM::CallOp>(loc, printfDecl, printfArgs);
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
MLIRContext * getContext() const
static LLVMArrayType get(Type elementType, unsigned numElements)
Gets or creates an instance of LLVM dialect array type containing numElements of elementType, in the same context as elementType.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
LLVM dialect function type.
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
static LLVMFunctionType get(Type result, ArrayRef< Type > arguments, bool isVarArg=false)
Gets or creates an instance of LLVM dialect function in the same context as the result type...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
IntegerAttr getI32IntegerAttr(int32_t value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
This class provides all of the information necessary to convert a type signature. ...
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
Attributes are known-constant values of operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory)
Builds IR creating a MemRef descriptor that represents type and populates it with static shape and st...
IntegerType getIntegerType(unsigned width)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type)
static LLVMPointerType get(MLIRContext *context, unsigned addressSpace=0)
Gets or creates an instance of LLVM dialect pointer type pointing to an object of pointee type in the...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before) override
PatternRewriter hook for moving blocks out of a region.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
TypeConverter * typeConverter
An optional type converter for use by this pattern.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
Type getType() const
Return the type of this value.
LLVMTypeConverter * getTypeConverter() const
LogicalResult matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implements a pattern rewriter for use with ConversionPatterns.
Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, SignatureConversion &result)
Convert a function type.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
This class provides an abstraction over the different types of ranges over Values.
StringAttr getStringAttr(const Twine &bytes)
FailureOr< Block * > convertRegionTypes(Region *region, TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
static const char formatStringPrefix[]